diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go new file mode 100644 index 0000000..7ef75a5 --- /dev/null +++ b/engine/reload_rules_test.go @@ -0,0 +1,154 @@ +package engine + +import ( + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/ruleset" + + "github.com/bwmarrin/snowflake" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/reassembly" +) + +type fixedRuleset struct { + action ruleset.Action +} + +func (r fixedRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { + return nil +} + +func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult { + return ruleset.MatchResult{Action: r.action} +} + +type noopTestLogger struct{} + +func (noopTestLogger) WorkerStart(int) {} +func (noopTestLogger) WorkerStop(int) {} + +func (noopTestLogger) TCPStreamNew(int, ruleset.StreamInfo) {} +func (noopTestLogger) TCPStreamPropUpdate(ruleset.StreamInfo, bool) { +} +func (noopTestLogger) TCPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) { +} + +func (noopTestLogger) UDPStreamNew(int, ruleset.StreamInfo) {} +func (noopTestLogger) UDPStreamPropUpdate(ruleset.StreamInfo, bool) { +} +func (noopTestLogger) UDPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) { +} + +func (noopTestLogger) ModifyError(ruleset.StreamInfo, error) {} + +func (noopTestLogger) AnalyzerDebugf(int64, string, string, ...interface{}) {} +func (noopTestLogger) AnalyzerInfof(int64, string, string, ...interface{}) {} +func (noopTestLogger) AnalyzerErrorf(int64, string, string, ...interface{}) {} + +func TestUDPStreamUsesUpdatedRuleset(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + f := &udpStreamFactory{ + WorkerID: 0, + Logger: noopTestLogger{}, + Node: node, + Ruleset: fixedRuleset{action: ruleset.ActionAllow}, + } + + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) + udp := &layers.UDP{ + SrcPort: 12345, + DstPort: 53, + BaseLayer: layers.BaseLayer{ + Payload: []byte("query"), + }, + } + ctx := &udpContext{Verdict: udpVerdictAccept} + s := f.New(ipFlow, udp.TransportFlow(), udp, ctx) + + if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { + t.Fatalf("update ruleset: %v", err) + } + + if !s.Accept(udp, false, ctx) { + t.Fatalf("unexpected Accept=false for virgin stream") + } + s.Feed(udp, false, ctx) + if ctx.Verdict != udpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream) + } +} + +func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + f := &tcpStreamFactory{ + WorkerID: 0, + Logger: noopTestLogger{}, + Node: node, + Ruleset: fixedRuleset{action: ruleset.ActionAllow}, + } + + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) + tcp := &layers.TCP{ + SrcPort: 12345, + DstPort: 443, + } + ctx := &tcpContext{ + PacketMetadata: &gopacket.PacketMetadata{}, + Verdict: tcpVerdictAccept, + } + rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx) + s, ok := rs.(*tcpStream) + if !ok { + t.Fatalf("unexpected stream type %T", rs) + } + + if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { + t.Fatalf("update ruleset: %v", err) + } + + s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx) + if ctx.Verdict != tcpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream) + } +} + +type fakeScatterGather struct { + data []byte +} + +func (s fakeScatterGather) Lengths() (int, int) { + return len(s.data), 0 +} + +func (s fakeScatterGather) Fetch(length int) []byte { + if length < 0 { + return nil + } + if length > len(s.data) { + length = len(s.data) + } + return s.data[:length] +} + +func (fakeScatterGather) KeepFrom(int) {} + +func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo { + return gopacket.CaptureInfo{} +} + +func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) { + return reassembly.TCPDirClientToServer, true, false, 0 +} + +func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats { + return reassembly.TCPAssemblyStats{} +} diff --git a/engine/tcp.go b/engine/tcp.go index 52b8f3d..5f9c171 100644 --- a/engine/tcp.go +++ b/engine/tcp.go @@ -60,9 +60,7 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a Props: make(analyzer.CombinedPropMap), } f.Logger.TCPStreamNew(f.WorkerID, info) - f.RulesetMutex.RLock() - rs := f.Ruleset - f.RulesetMutex.RUnlock() + rs := f.currentRuleset() ans := analyzersToTCPAnalyzers(rs.Analyzers(info)) // Create entries for each analyzer entries := make([]*tcpStreamEntry, 0, len(ans)) @@ -87,7 +85,7 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a info: info, virgin: true, logger: f.Logger, - ruleset: rs, + rulesetSource: f.currentRuleset, activeEntries: entries, } } @@ -99,11 +97,17 @@ func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { return nil } +func (f *tcpStreamFactory) currentRuleset() ruleset.Ruleset { + f.RulesetMutex.RLock() + defer f.RulesetMutex.RUnlock() + return f.Ruleset +} + type tcpStream struct { info ruleset.StreamInfo virgin bool // true if no packets have been processed logger Logger - ruleset ruleset.Ruleset + rulesetSource func() ruleset.Ruleset activeEntries []*tcpStreamEntry doneEntries []*tcpStreamEntry lastVerdict tcpVerdict @@ -152,7 +156,10 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass s.virgin = false s.logger.TCPStreamPropUpdate(s.info, false) // Match properties against ruleset - result := s.ruleset.Match(s.info) + result := ruleset.MatchResult{Action: ruleset.ActionMaybe} + if rs := s.currentRuleset(); rs != nil { + result = rs.Match(s.info) + } action := result.Action if action != ruleset.ActionMaybe && action != ruleset.ActionModify { verdict := actionToTCPVerdict(action) @@ -171,6 +178,13 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass } } +func (s *tcpStream) currentRuleset() ruleset.Ruleset { + if s.rulesetSource == nil { + return nil + } + return s.rulesetSource() +} + func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { s.closeActiveEntries() return true diff --git a/engine/udp.go b/engine/udp.go index 6f3b6e9..2407d47 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -60,9 +60,7 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u Props: make(analyzer.CombinedPropMap), } f.Logger.UDPStreamNew(f.WorkerID, info) - f.RulesetMutex.RLock() - rs := f.Ruleset - f.RulesetMutex.RUnlock() + rs := f.currentRuleset() ans := analyzersToUDPAnalyzers(rs.Analyzers(info)) // Create entries for each analyzer entries := make([]*udpStreamEntry, 0, len(ans)) @@ -87,7 +85,7 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u info: info, virgin: true, logger: f.Logger, - ruleset: rs, + rulesetSource: f.currentRuleset, activeEntries: entries, } } @@ -99,6 +97,12 @@ func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { return nil } +func (f *udpStreamFactory) currentRuleset() ruleset.Ruleset { + f.RulesetMutex.RLock() + defer f.RulesetMutex.RUnlock() + return f.Ruleset +} + type udpStreamManager struct { factory *udpStreamFactory streams *lru.Cache[uint32, *udpStreamValue] @@ -186,7 +190,7 @@ type udpStream struct { info ruleset.StreamInfo virgin bool // true if no packets have been processed logger Logger - ruleset ruleset.Ruleset + rulesetSource func() ruleset.Ruleset activeEntries []*udpStreamEntry doneEntries []*udpStreamEntry lastVerdict udpVerdict @@ -229,7 +233,10 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { s.virgin = false s.logger.UDPStreamPropUpdate(s.info, false) // Match properties against ruleset - result := s.ruleset.Match(s.info) + result := ruleset.MatchResult{Action: ruleset.ActionMaybe} + if rs := s.currentRuleset(); rs != nil { + result = rs.Match(s.info) + } action := result.Action if action == ruleset.ActionModify { // Call the modifier instance @@ -266,6 +273,13 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { } } +func (s *udpStream) currentRuleset() ruleset.Ruleset { + if s.rulesetSource == nil { + return nil + } + return s.rulesetSource() +} + func (s *udpStream) Close() { s.closeActiveEntries() }