ruleset: try to fix reloader

This commit is contained in:
2026-02-12 13:34:37 +05:30
parent beaaddad2b
commit e1c68ec7d0
3 changed files with 195 additions and 46 deletions

View File

@@ -84,6 +84,59 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
} }
} }
func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &udpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
udp := &layers.UDP{
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx1 := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1)
if !s.Accept(udp, false, ctx1) {
t.Fatalf("unexpected Accept=false before first feed")
}
s.Feed(udp, false, ctx1)
if ctx1.Verdict != udpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
}
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
}
ctx2 := &udpContext{Verdict: udpVerdictAccept}
if !s.Accept(udp, false, ctx2) {
t.Fatalf("expected Accept=true after ruleset update")
}
s.Feed(udp, false, ctx2)
if ctx2.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
}
ctx3 := &udpContext{Verdict: udpVerdictAccept}
if s.Accept(udp, false, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
}
if ctx3.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx3.Verdict, udpVerdictDropStream)
}
}
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
node, err := snowflake.NewNode(0) node, err := snowflake.NewNode(0)
if err != nil { if err != nil {
@@ -121,6 +174,72 @@ func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
} }
} }
func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &tcpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
tcp := &layers.TCP{
SrcPort: 12345,
DstPort: 443,
}
ctx1 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
}
start1 := false
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) {
t.Fatalf("unexpected Accept=false before first feed")
}
s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1)
if ctx1.Verdict != tcpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream)
}
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
}
ctx2 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
start2 := false
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) {
t.Fatalf("expected Accept=true after ruleset update")
}
s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2)
if ctx2.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream)
}
ctx3 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
start3 := false
if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
}
if ctx3.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream)
}
}
type fakeScatterGather struct { type fakeScatterGather struct {
data []byte data []byte
} }

View File

@@ -42,6 +42,7 @@ type tcpStreamFactory struct {
RulesetMutex sync.RWMutex RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
RulesetVersion uint64
} }
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
@@ -60,8 +61,11 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.TCPStreamNew(f.WorkerID, info) f.Logger.TCPStreamNew(f.WorkerID, info)
rs := f.currentRuleset() rs, version := f.currentRuleset()
ans := analyzersToTCPAnalyzers(rs.Analyzers(info)) var ans []analyzer.TCPAnalyzer
if rs != nil {
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
}
// Create entries for each analyzer // Create entries for each analyzer
entries := make([]*tcpStreamEntry, 0, len(ans)) entries := make([]*tcpStreamEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
@@ -85,6 +89,7 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a
info: info, info: info,
virgin: true, virgin: true,
logger: f.Logger, logger: f.Logger,
rulesetVersion: version,
rulesetSource: f.currentRuleset, rulesetSource: f.currentRuleset,
activeEntries: entries, activeEntries: entries,
} }
@@ -94,20 +99,22 @@ func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
f.RulesetMutex.Lock() f.RulesetMutex.Lock()
defer f.RulesetMutex.Unlock() defer f.RulesetMutex.Unlock()
f.Ruleset = r f.Ruleset = r
f.RulesetVersion++
return nil return nil
} }
func (f *tcpStreamFactory) currentRuleset() ruleset.Ruleset { func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
f.RulesetMutex.RLock() f.RulesetMutex.RLock()
defer f.RulesetMutex.RUnlock() defer f.RulesetMutex.RUnlock()
return f.Ruleset return f.Ruleset, f.RulesetVersion
} }
type tcpStream struct { type tcpStream struct {
info ruleset.StreamInfo info ruleset.StreamInfo
virgin bool // true if no packets have been processed virgin bool // true if no packets have been processed
logger Logger logger Logger
rulesetSource func() ruleset.Ruleset rulesetVersion uint64
rulesetSource func() (ruleset.Ruleset, uint64)
activeEntries []*tcpStreamEntry activeEntries []*tcpStreamEntry
doneEntries []*tcpStreamEntry doneEntries []*tcpStreamEntry
lastVerdict tcpVerdict lastVerdict tcpVerdict
@@ -121,7 +128,7 @@ type tcpStreamEntry struct {
} }
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
if len(s.activeEntries) > 0 || s.virgin { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once, // Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in // even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched. // properties that need to be matched.
@@ -152,12 +159,15 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
} }
} }
ctx := ac.(*tcpContext) ctx := ac.(*tcpContext)
if updated || s.virgin { rs, version := s.currentRuleset()
rulesetChanged := version != s.rulesetVersion
s.rulesetVersion = version
if updated || s.virgin || rulesetChanged {
s.virgin = false s.virgin = false
s.logger.TCPStreamPropUpdate(s.info, false) s.logger.TCPStreamPropUpdate(s.info, false)
// Match properties against ruleset // Match properties against ruleset
result := ruleset.MatchResult{Action: ruleset.ActionMaybe} result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs := s.currentRuleset(); rs != nil { if rs != nil {
result = rs.Match(s.info) result = rs.Match(s.info)
} }
action := result.Action action := result.Action
@@ -178,13 +188,18 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
} }
} }
func (s *tcpStream) currentRuleset() ruleset.Ruleset { func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) {
if s.rulesetSource == nil { if s.rulesetSource == nil {
return nil return nil, s.rulesetVersion
} }
return s.rulesetSource() return s.rulesetSource()
} }
func (s *tcpStream) rulesetChanged() bool {
_, version := s.currentRuleset()
return version != s.rulesetVersion
}
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
s.closeActiveEntries() s.closeActiveEntries()
return true return true

View File

@@ -43,6 +43,7 @@ type udpStreamFactory struct {
RulesetMutex sync.RWMutex RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
RulesetVersion uint64
} }
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
@@ -60,8 +61,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.UDPStreamNew(f.WorkerID, info) f.Logger.UDPStreamNew(f.WorkerID, info)
rs := f.currentRuleset() rs, version := f.currentRuleset()
ans := analyzersToUDPAnalyzers(rs.Analyzers(info)) var ans []analyzer.UDPAnalyzer
if rs != nil {
ans = analyzersToUDPAnalyzers(rs.Analyzers(info))
}
// Create entries for each analyzer // Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans)) entries := make([]*udpStreamEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
@@ -85,6 +89,7 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
info: info, info: info,
virgin: true, virgin: true,
logger: f.Logger, logger: f.Logger,
rulesetVersion: version,
rulesetSource: f.currentRuleset, rulesetSource: f.currentRuleset,
activeEntries: entries, activeEntries: entries,
} }
@@ -94,13 +99,14 @@ func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
f.RulesetMutex.Lock() f.RulesetMutex.Lock()
defer f.RulesetMutex.Unlock() defer f.RulesetMutex.Unlock()
f.Ruleset = r f.Ruleset = r
f.RulesetVersion++
return nil return nil
} }
func (f *udpStreamFactory) currentRuleset() ruleset.Ruleset { func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
f.RulesetMutex.RLock() f.RulesetMutex.RLock()
defer f.RulesetMutex.RUnlock() defer f.RulesetMutex.RUnlock()
return f.Ruleset return f.Ruleset, f.RulesetVersion
} }
type udpStreamManager struct { type udpStreamManager struct {
@@ -190,7 +196,8 @@ type udpStream struct {
info ruleset.StreamInfo info ruleset.StreamInfo
virgin bool // true if no packets have been processed virgin bool // true if no packets have been processed
logger Logger logger Logger
rulesetSource func() ruleset.Ruleset rulesetVersion uint64
rulesetSource func() (ruleset.Ruleset, uint64)
activeEntries []*udpStreamEntry activeEntries []*udpStreamEntry
doneEntries []*udpStreamEntry doneEntries []*udpStreamEntry
lastVerdict udpVerdict lastVerdict udpVerdict
@@ -204,7 +211,7 @@ type udpStreamEntry struct {
} }
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
if len(s.activeEntries) > 0 || s.virgin { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once, // Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in // even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched. // properties that need to be matched.
@@ -229,12 +236,15 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
s.doneEntries = append(s.doneEntries, entry) s.doneEntries = append(s.doneEntries, entry)
} }
} }
if updated || s.virgin { rs, version := s.currentRuleset()
rulesetChanged := version != s.rulesetVersion
s.rulesetVersion = version
if updated || s.virgin || rulesetChanged {
s.virgin = false s.virgin = false
s.logger.UDPStreamPropUpdate(s.info, false) s.logger.UDPStreamPropUpdate(s.info, false)
// Match properties against ruleset // Match properties against ruleset
result := ruleset.MatchResult{Action: ruleset.ActionMaybe} result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs := s.currentRuleset(); rs != nil { if rs != nil {
result = rs.Match(s.info) result = rs.Match(s.info)
} }
action := result.Action action := result.Action
@@ -273,13 +283,18 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
} }
} }
func (s *udpStream) currentRuleset() ruleset.Ruleset { func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) {
if s.rulesetSource == nil { if s.rulesetSource == nil {
return nil return nil, s.rulesetVersion
} }
return s.rulesetSource() return s.rulesetSource()
} }
func (s *udpStream) rulesetChanged() bool {
_, version := s.currentRuleset()
return version != s.rulesetVersion
}
func (s *udpStream) Close() { func (s *udpStream) Close() {
s.closeActiveEntries() s.closeActiveEntries()
} }