package engine

import (
	"net"
	"testing"

	"git.difuse.io/Difuse/Mellaris/analyzer"
	"git.difuse.io/Difuse/Mellaris/ruleset"
	"github.com/bwmarrin/snowflake"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/reassembly"
)

// fixedRuleset is a test ruleset that requests no analyzers and answers
// every Match call with the same configured action, regardless of the
// stream being matched.
type fixedRuleset struct {
	action ruleset.Action
}

// Analyzers returns nil: streams judged by fixedRuleset run no analyzers.
func (r fixedRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil }

// Match ignores the stream info and always returns r.action.
func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
	return ruleset.MatchResult{Action: r.action}
}

// noopTestLogger is plugged into the stream factories' Logger field in these
// tests; every method discards its arguments.
type noopTestLogger struct{}

func (noopTestLogger) WorkerStart(int)                                          {}
func (noopTestLogger) WorkerStop(int)                                           {}
func (noopTestLogger) TCPStreamNew(int, ruleset.StreamInfo)                     {}
func (noopTestLogger) TCPStreamPropUpdate(ruleset.StreamInfo, bool)             {}
func (noopTestLogger) TCPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) {}
func (noopTestLogger) UDPStreamNew(int, ruleset.StreamInfo)                     {}
func (noopTestLogger) UDPStreamPropUpdate(ruleset.StreamInfo, bool)             {}
func (noopTestLogger) UDPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) {}
func (noopTestLogger) ModifyError(ruleset.StreamInfo, error)                    {}
func (noopTestLogger) AnalyzerDebugf(int64, string, string, ...interface{})     {}
func (noopTestLogger) AnalyzerInfof(int64, string, string, ...interface{})      {}
func (noopTestLogger) AnalyzerErrorf(int64, string, string, ...interface{})     {}

// TestUDPStreamUsesUpdatedRuleset checks that a UDP stream is judged by the
// ruleset installed via UpdateRuleset after the stream was created: the
// stream is created while the factory's ruleset allows, the ruleset is then
// swapped for one that blocks, and feeding the first packet must set the
// verdict to udpVerdictDropStream.
func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
	node, err := snowflake.NewNode(0)
	if err != nil {
		t.Fatalf("create node: %v", err)
	}
	f := &udpStreamFactory{
		WorkerID: 0,
		Logger:   noopTestLogger{},
		Node:     node,
		Ruleset:  fixedRuleset{action: ruleset.ActionAllow},
	}

	ipFlow := gopacket.NewFlow(layers.EndpointIPv4,
		net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
	udp := &layers.UDP{
		SrcPort: 12345,
		DstPort: 53,
		BaseLayer: layers.BaseLayer{
			Payload: []byte("query"),
		},
	}
	// udp.TransportFlow() would yield empty endpoints here: gopacket only
	// fills the port bytes backing that flow when the layer is decoded from
	// raw bytes. Build the flow from the port fields instead.
	udpFlow, err := gopacket.FlowFromEndpoints(
		layers.NewUDPPortEndpoint(udp.SrcPort),
		layers.NewUDPPortEndpoint(udp.DstPort))
	if err != nil {
		t.Fatalf("build udp flow: %v", err)
	}
	ctx := &udpContext{Verdict: udpVerdictAccept}
	s := f.New(ipFlow, udpFlow, udp, ctx)

	// Swap the ruleset only after the stream exists, so the verdict below
	// proves the stream consults the factory's current ruleset.
	if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
		t.Fatalf("update ruleset: %v", err)
	}
	if !s.Accept(udp, false, ctx) {
		t.Fatal("unexpected Accept=false for virgin stream")
	}
	s.Feed(udp, false, ctx)
	if ctx.Verdict != udpVerdictDropStream {
		t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
	}
}

// TestTCPStreamUsesUpdatedRuleset is the TCP counterpart of
// TestUDPStreamUsesUpdatedRuleset: a stream created under an allowing
// ruleset must be dropped once the factory's ruleset is replaced by a
// blocking one and reassembled data arrives.
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
	node, err := snowflake.NewNode(0)
	if err != nil {
		t.Fatalf("create node: %v", err)
	}
	f := &tcpStreamFactory{
		WorkerID: 0,
		Logger:   noopTestLogger{},
		Node:     node,
		Ruleset:  fixedRuleset{action: ruleset.ActionAllow},
	}

	ipFlow := gopacket.NewFlow(layers.EndpointIPv4,
		net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
	tcp := &layers.TCP{
		SrcPort: 12345,
		DstPort: 443,
	}
	// As with UDP, tcp.TransportFlow() is only populated for decoded
	// layers, so build the flow from the port fields explicitly.
	tcpFlow, err := gopacket.FlowFromEndpoints(
		layers.NewTCPPortEndpoint(tcp.SrcPort),
		layers.NewTCPPortEndpoint(tcp.DstPort))
	if err != nil {
		t.Fatalf("build tcp flow: %v", err)
	}
	ctx := &tcpContext{
		PacketMetadata: &gopacket.PacketMetadata{},
		Verdict:        tcpVerdictAccept,
	}

	rs := f.New(ipFlow, tcpFlow, tcp, ctx)
	s, ok := rs.(*tcpStream)
	if !ok {
		t.Fatalf("unexpected stream type %T", rs)
	}

	// Swap the ruleset only after the stream exists (see the UDP test).
	if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
		t.Fatalf("update ruleset: %v", err)
	}
	s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx)
	if ctx.Verdict != tcpVerdictDropStream {
		t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream)
	}
}

// Compile-time check that the fake keeps up with the reassembly interface.
var _ reassembly.ScatterGather = fakeScatterGather{}

// fakeScatterGather is a minimal reassembly.ScatterGather that presents data
// as one contiguous, client-to-server chunk at stream start, with no saved
// bytes, no skipped bytes and no end-of-stream flag.
type fakeScatterGather struct {
	data []byte
}

// Lengths reports all of data as available and nothing as saved.
func (s fakeScatterGather) Lengths() (int, int) { return len(s.data), 0 }

// Fetch returns the first length bytes of data, clamped to len(data);
// a negative length yields nil.
func (s fakeScatterGather) Fetch(length int) []byte {
	if length < 0 {
		return nil
	}
	if length > len(s.data) {
		length = len(s.data)
	}
	return s.data[:length]
}

// KeepFrom is a no-op: the fake never retains data between calls.
func (fakeScatterGather) KeepFrom(int) {}

// CaptureInfo returns a zero CaptureInfo for every offset.
func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo {
	return gopacket.CaptureInfo{}
}

// Info reports client-to-server direction, start=true, end=false, skip=0.
func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) {
	return reassembly.TCPDirClientToServer, true, false, 0
}

// Stats returns zero assembly statistics.
func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats {
	return reassembly.TCPAssemblyStats{}
}