Files
Mellaris/engine/udp_manager_tuple_test.go
T

69 lines
2.0 KiB
Go

package engine
import (
"sync/atomic"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
)
// countingRuleset is a test ruleset stub. It hands every stream the same
// fixed analyzer list. Its Match always returns a MatchResult whose Action is
// ruleset.ActionMaybe, whatever the stream info says.
type countingRuleset struct {
	// ans is returned verbatim by Analyzers for every stream.
	ans []analyzer.Analyzer
}

// Analyzers ignores the stream info and returns the configured analyzer list.
func (rs countingRuleset) Analyzers(_ ruleset.StreamInfo) []analyzer.Analyzer {
	return rs.ans
}

// Match ignores the stream info and reports ruleset.ActionMaybe.
func (countingRuleset) Match(_ ruleset.StreamInfo) ruleset.MatchResult {
	var res ruleset.MatchResult
	res.Action = ruleset.ActionMaybe
	return res
}
// countingUDPAnalyzer is a stub UDP analyzer. It records how many times
// NewUDP is called, which lets tests observe how often the stream manager
// creates a new stream.
type countingUDPAnalyzer struct {
	// newCalls is shared with the test and incremented once per NewUDP call.
	newCalls *atomic.Uint64
}

// Name returns the analyzer identifier "countudp".
func (countingUDPAnalyzer) Name() string {
	return "countudp"
}

// Limit returns 0.
// NOTE(review): the meaning of 0 is defined by the analyzer package
// (presumably "no byte limit") — confirm there.
func (countingUDPAnalyzer) Limit() int {
	return 0
}

// NewUDP atomically bumps the shared call counter and returns a no-op stream.
// The UDP info and logger arguments are ignored.
func (ca countingUDPAnalyzer) NewUDP(_ analyzer.UDPInfo, _ analyzer.Logger) analyzer.UDPStream {
	ca.newCalls.Add(1)
	return countingUDPStream{}
}
// countingUDPStream is the stream returned by countingUDPAnalyzer.NewUDP.
// It inspects nothing: Feed always yields (nil, false) and Close yields nil.
type countingUDPStream struct{}

// Feed ignores the direction flag and the payload and returns (nil, false).
func (s countingUDPStream) Feed(_ bool, _ []byte) (*analyzer.PropUpdate, bool) {
	return nil, false
}

// Close ignores its flag and returns no property update.
func (s countingUDPStream) Close(_ bool) *analyzer.PropUpdate {
	return nil
}
// TestUDPStreamManagerRebindsByTupleInO1Path feeds two packets with the same
// udpTupleKey into MatchWithContext. The two calls use different leading
// arguments (100, then 200). The analyzer's NewUDP must run exactly once
// across both calls, meaning the second packet is matched to the stream
// created for the first purely by its tuple.
// NOTE(review): the leading argument is presumably a stream/flow ID —
// confirm against MatchWithContext.
func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
	var created atomic.Uint64

	snowNode, err := snowflake.NewNode(0)
	if err != nil {
		t.Fatalf("create node: %v", err)
	}

	factory := &udpStreamFactory{
		WorkerID: 0,
		Logger:   noopTestLogger{},
		Node:     snowNode,
		Ruleset: countingRuleset{
			ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &created}},
		},
	}

	mgr, err := newUDPStreamManager(factory, 64, &statsCounters{})
	if err != nil {
		t.Fatalf("new manager: %v", err)
	}

	// requireSingleStream fails the test, using format for the message,
	// unless NewUDP has been invoked exactly once so far.
	requireSingleStream := func(format string) {
		t.Helper()
		if n := created.Load(); n != 1 {
			t.Fatalf(format, n)
		}
	}

	key := udpTupleKey{
		AIP: [16]byte{10, 0, 0, 1}, ALen: 4, APort: 50000,
		BIP: [16]byte{10, 0, 0, 2}, BLen: 4, BPort: 443,
	}
	data := []byte{0x01, 0x00, 0x00, 0x00}

	// First packet for this tuple: exactly one stream should be created.
	mgr.MatchWithContext(100, key, false, data, &udpContext{Verdict: udpVerdictAccept})
	requireSingleStream("new stream calls=%d want=1")

	// Same tuple, different leading argument: the existing stream must be reused.
	mgr.MatchWithContext(200, key, false, data, &udpContext{Verdict: udpVerdictAccept})
	requireSingleStream("expected stream reuse by tuple, new stream calls=%d want=1")
}