From e1c68ec7d09b6ea348ff2f7536ac32aacf4215eb Mon Sep 17 00:00:00 2001 From: hayzam Date: Thu, 12 Feb 2026 13:34:37 +0530 Subject: [PATCH] ruleset: try to fix reloader --- engine/reload_rules_test.go | 119 ++++++++++++++++++++++++++++++++++++ engine/tcp.go | 61 +++++++++++------- engine/udp.go | 61 +++++++++++------- 3 files changed, 195 insertions(+), 46 deletions(-) diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go index 7ef75a5..fa3fdea 100644 --- a/engine/reload_rules_test.go +++ b/engine/reload_rules_test.go @@ -84,6 +84,59 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) { } } +func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + f := &udpStreamFactory{ + WorkerID: 0, + Logger: noopTestLogger{}, + Node: node, + Ruleset: fixedRuleset{action: ruleset.ActionAllow}, + } + + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) + udp := &layers.UDP{ + SrcPort: 12345, + DstPort: 53, + BaseLayer: layers.BaseLayer{ + Payload: []byte("query"), + }, + } + + ctx1 := &udpContext{Verdict: udpVerdictAccept} + s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1) + if !s.Accept(udp, false, ctx1) { + t.Fatalf("unexpected Accept=false before first feed") + } + s.Feed(udp, false, ctx1) + if ctx1.Verdict != udpVerdictAcceptStream { + t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream) + } + + if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { + t.Fatalf("update ruleset: %v", err) + } + + ctx2 := &udpContext{Verdict: udpVerdictAccept} + if !s.Accept(udp, false, ctx2) { + t.Fatalf("expected Accept=true after ruleset update") + } + s.Feed(udp, false, ctx2) + if ctx2.Verdict != udpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream) + } + + ctx3 := &udpContext{Verdict: udpVerdictAccept} + if s.Accept(udp, false, ctx3) { + t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") + } + if ctx3.Verdict != udpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx3.Verdict, udpVerdictDropStream) + } +} + func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { node, err := snowflake.NewNode(0) if err != nil { @@ -121,6 +174,72 @@ func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { } } +func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + f := &tcpStreamFactory{ + WorkerID: 0, + Logger: noopTestLogger{}, + Node: node, + Ruleset: fixedRuleset{action: ruleset.ActionAllow}, + } + + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) + tcp := &layers.TCP{ + SrcPort: 12345, + DstPort: 443, + } + ctx1 := &tcpContext{ + PacketMetadata: &gopacket.PacketMetadata{}, + Verdict: tcpVerdictAccept, + } + rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1) + s, ok := rs.(*tcpStream) + if !ok { + t.Fatalf("unexpected stream type %T", rs) + } + + start1 := false + if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) { + t.Fatalf("unexpected Accept=false before first feed") + } + s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1) + if ctx1.Verdict != tcpVerdictAcceptStream { + t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream) + } + + if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { + t.Fatalf("update ruleset: %v", err) + } + + ctx2 := &tcpContext{ + PacketMetadata: &gopacket.PacketMetadata{}, + Verdict: tcpVerdictAccept, + } + start2 := false + if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) { + t.Fatalf("expected Accept=true after ruleset update") + } + s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2) + if ctx2.Verdict != tcpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream) + } + + ctx3 := &tcpContext{ + PacketMetadata: &gopacket.PacketMetadata{}, + Verdict: tcpVerdictAccept, + } + start3 := false + if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) { + t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") + } + if ctx3.Verdict != tcpVerdictDropStream { + t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream) + } +} + type fakeScatterGather struct { data []byte } diff --git a/engine/tcp.go b/engine/tcp.go index 5f9c171..4bb8c8c 100644 --- a/engine/tcp.go +++ b/engine/tcp.go @@ -40,8 +40,9 @@ type tcpStreamFactory struct { Logger Logger Node *snowflake.Node - RulesetMutex sync.RWMutex - Ruleset ruleset.Ruleset + RulesetMutex sync.RWMutex + Ruleset ruleset.Ruleset + RulesetVersion uint64 } func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { @@ -60,8 +61,11 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a Props: make(analyzer.CombinedPropMap), } f.Logger.TCPStreamNew(f.WorkerID, info) - rs := f.currentRuleset() - ans := analyzersToTCPAnalyzers(rs.Analyzers(info)) + rs, version := f.currentRuleset() + var ans []analyzer.TCPAnalyzer + if rs != nil { + ans = analyzersToTCPAnalyzers(rs.Analyzers(info)) + } // Create entries for each analyzer entries := make([]*tcpStreamEntry, 0, len(ans)) for _, a := range ans { @@ -82,11 +86,12 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a }) } return &tcpStream{ - info: info, - virgin: true, - logger: f.Logger, - rulesetSource: f.currentRuleset, - activeEntries: entries, + info: info, + virgin: true, + logger: f.Logger, + rulesetVersion: version, + rulesetSource: f.currentRuleset, + activeEntries: entries, } } @@ -94,23 +99,25 @@ func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { f.RulesetMutex.Lock() defer f.RulesetMutex.Unlock() f.Ruleset = r + f.RulesetVersion++ return nil } -func (f *tcpStreamFactory) currentRuleset() ruleset.Ruleset { +func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { f.RulesetMutex.RLock() defer f.RulesetMutex.RUnlock() - return f.Ruleset + return f.Ruleset, f.RulesetVersion } type tcpStream struct { - info ruleset.StreamInfo - virgin bool // true if no packets have been processed - logger Logger - rulesetSource func() ruleset.Ruleset - activeEntries []*tcpStreamEntry - doneEntries []*tcpStreamEntry - lastVerdict tcpVerdict + info ruleset.StreamInfo + virgin bool // true if no packets have been processed + logger Logger + rulesetVersion uint64 + rulesetSource func() (ruleset.Ruleset, uint64) + activeEntries []*tcpStreamEntry + doneEntries []*tcpStreamEntry + lastVerdict tcpVerdict } type tcpStreamEntry struct { @@ -121,7 +128,7 @@ type tcpStreamEntry struct { } func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { - if len(s.activeEntries) > 0 || s.virgin { + if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { // Make sure every stream matches against the ruleset at least once, // even if there are no activeEntries, as the ruleset may have built-in // properties that need to be matched. @@ -152,12 +159,15 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass } } ctx := ac.(*tcpContext) - if updated || s.virgin { + rs, version := s.currentRuleset() + rulesetChanged := version != s.rulesetVersion + s.rulesetVersion = version + if updated || s.virgin || rulesetChanged { s.virgin = false s.logger.TCPStreamPropUpdate(s.info, false) // Match properties against ruleset result := ruleset.MatchResult{Action: ruleset.ActionMaybe} - if rs := s.currentRuleset(); rs != nil { + if rs != nil { result = rs.Match(s.info) } action := result.Action @@ -178,13 +188,18 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass } } -func (s *tcpStream) currentRuleset() ruleset.Ruleset { +func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) { if s.rulesetSource == nil { - return nil + return nil, s.rulesetVersion } return s.rulesetSource() } +func (s *tcpStream) rulesetChanged() bool { + _, version := s.currentRuleset() + return version != s.rulesetVersion +} + func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { s.closeActiveEntries() return true diff --git a/engine/udp.go b/engine/udp.go index 2407d47..7668ebf 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -41,8 +41,9 @@ type udpStreamFactory struct { Logger Logger Node *snowflake.Node - RulesetMutex sync.RWMutex - Ruleset ruleset.Ruleset + RulesetMutex sync.RWMutex + Ruleset ruleset.Ruleset + RulesetVersion uint64 } func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { @@ -60,8 +61,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u Props: make(analyzer.CombinedPropMap), } f.Logger.UDPStreamNew(f.WorkerID, info) - rs := f.currentRuleset() - ans := analyzersToUDPAnalyzers(rs.Analyzers(info)) + rs, version := f.currentRuleset() + var ans []analyzer.UDPAnalyzer + if rs != nil { + ans = analyzersToUDPAnalyzers(rs.Analyzers(info)) + } // Create entries for each analyzer entries := make([]*udpStreamEntry, 0, len(ans)) for _, a := range ans { @@ -82,11 +86,12 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u }) } return &udpStream{ - info: info, - virgin: true, - logger: f.Logger, - rulesetSource: f.currentRuleset, - activeEntries: entries, + info: info, + virgin: true, + logger: f.Logger, + rulesetVersion: version, + rulesetSource: f.currentRuleset, + activeEntries: entries, } } @@ -94,13 +99,14 @@ func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { f.RulesetMutex.Lock() defer f.RulesetMutex.Unlock() f.Ruleset = r + f.RulesetVersion++ return nil } -func (f *udpStreamFactory) currentRuleset() ruleset.Ruleset { +func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { f.RulesetMutex.RLock() defer f.RulesetMutex.RUnlock() - return f.Ruleset + return f.Ruleset, f.RulesetVersion } type udpStreamManager struct { @@ -187,13 +193,14 @@ func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32 } type udpStream struct { - info ruleset.StreamInfo - virgin bool // true if no packets have been processed - logger Logger - rulesetSource func() ruleset.Ruleset - activeEntries []*udpStreamEntry - doneEntries []*udpStreamEntry - lastVerdict udpVerdict + info ruleset.StreamInfo + virgin bool // true if no packets have been processed + logger Logger + rulesetVersion uint64 + rulesetSource func() (ruleset.Ruleset, uint64) + activeEntries []*udpStreamEntry + doneEntries []*udpStreamEntry + lastVerdict udpVerdict } type udpStreamEntry struct { @@ -204,7 +211,7 @@ type udpStreamEntry struct { } func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { - if len(s.activeEntries) > 0 || s.virgin { + if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { // Make sure every stream matches against the ruleset at least once, // even if there are no activeEntries, as the ruleset may have built-in // properties that need to be matched. @@ -229,12 +236,15 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { s.doneEntries = append(s.doneEntries, entry) } } - if updated || s.virgin { + rs, version := s.currentRuleset() + rulesetChanged := version != s.rulesetVersion + s.rulesetVersion = version + if updated || s.virgin || rulesetChanged { s.virgin = false s.logger.UDPStreamPropUpdate(s.info, false) // Match properties against ruleset result := ruleset.MatchResult{Action: ruleset.ActionMaybe} - if rs := s.currentRuleset(); rs != nil { + if rs != nil { result = rs.Match(s.info) } action := result.Action @@ -273,13 +283,18 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { } } -func (s *udpStream) currentRuleset() ruleset.Ruleset { +func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) { if s.rulesetSource == nil { - return nil + return nil, s.rulesetVersion } return s.rulesetSource() } +func (s *udpStream) rulesetChanged() bool { + _, version := s.currentRuleset() + return version != s.rulesetVersion +} + func (s *udpStream) Close() { s.closeActiveEntries() }