ruleset: try to fix reloader

This commit is contained in:
2026-02-12 13:27:53 +05:30
parent a8f8b43f3e
commit beaaddad2b
3 changed files with 194 additions and 12 deletions

154
engine/reload_rules_test.go Normal file
View File

@@ -0,0 +1,154 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
type fixedRuleset struct {
action ruleset.Action
}
func (r fixedRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
return nil
}
func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
return ruleset.MatchResult{Action: r.action}
}
type noopTestLogger struct{}
func (noopTestLogger) WorkerStart(int) {}
func (noopTestLogger) WorkerStop(int) {}
func (noopTestLogger) TCPStreamNew(int, ruleset.StreamInfo) {}
func (noopTestLogger) TCPStreamPropUpdate(ruleset.StreamInfo, bool) {
}
func (noopTestLogger) TCPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) {
}
func (noopTestLogger) UDPStreamNew(int, ruleset.StreamInfo) {}
func (noopTestLogger) UDPStreamPropUpdate(ruleset.StreamInfo, bool) {
}
func (noopTestLogger) UDPStreamAction(ruleset.StreamInfo, ruleset.Action, bool) {
}
func (noopTestLogger) ModifyError(ruleset.StreamInfo, error) {}
func (noopTestLogger) AnalyzerDebugf(int64, string, string, ...interface{}) {}
func (noopTestLogger) AnalyzerInfof(int64, string, string, ...interface{}) {}
func (noopTestLogger) AnalyzerErrorf(int64, string, string, ...interface{}) {}
func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &udpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
udp := &layers.UDP{
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx)
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
}
if !s.Accept(udp, false, ctx) {
t.Fatalf("unexpected Accept=false for virgin stream")
}
s.Feed(udp, false, ctx)
if ctx.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
}
}
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &tcpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
tcp := &layers.TCP{
SrcPort: 12345,
DstPort: 443,
}
ctx := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
}
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
}
s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx)
if ctx.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream)
}
}
type fakeScatterGather struct {
data []byte
}
func (s fakeScatterGather) Lengths() (int, int) {
return len(s.data), 0
}
func (s fakeScatterGather) Fetch(length int) []byte {
if length < 0 {
return nil
}
if length > len(s.data) {
length = len(s.data)
}
return s.data[:length]
}
func (fakeScatterGather) KeepFrom(int) {}
func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo {
return gopacket.CaptureInfo{}
}
func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) {
return reassembly.TCPDirClientToServer, true, false, 0
}
func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats {
return reassembly.TCPAssemblyStats{}
}

View File

@@ -60,9 +60,7 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.TCPStreamNew(f.WorkerID, info) f.Logger.TCPStreamNew(f.WorkerID, info)
f.RulesetMutex.RLock() rs := f.currentRuleset()
rs := f.Ruleset
f.RulesetMutex.RUnlock()
ans := analyzersToTCPAnalyzers(rs.Analyzers(info)) ans := analyzersToTCPAnalyzers(rs.Analyzers(info))
// Create entries for each analyzer // Create entries for each analyzer
entries := make([]*tcpStreamEntry, 0, len(ans)) entries := make([]*tcpStreamEntry, 0, len(ans))
@@ -87,7 +85,7 @@ func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, a
info: info, info: info,
virgin: true, virgin: true,
logger: f.Logger, logger: f.Logger,
ruleset: rs, rulesetSource: f.currentRuleset,
activeEntries: entries, activeEntries: entries,
} }
} }
@@ -99,11 +97,17 @@ func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
return nil return nil
} }
func (f *tcpStreamFactory) currentRuleset() ruleset.Ruleset {
f.RulesetMutex.RLock()
defer f.RulesetMutex.RUnlock()
return f.Ruleset
}
type tcpStream struct { type tcpStream struct {
info ruleset.StreamInfo info ruleset.StreamInfo
virgin bool // true if no packets have been processed virgin bool // true if no packets have been processed
logger Logger logger Logger
ruleset ruleset.Ruleset rulesetSource func() ruleset.Ruleset
activeEntries []*tcpStreamEntry activeEntries []*tcpStreamEntry
doneEntries []*tcpStreamEntry doneEntries []*tcpStreamEntry
lastVerdict tcpVerdict lastVerdict tcpVerdict
@@ -152,7 +156,10 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
s.virgin = false s.virgin = false
s.logger.TCPStreamPropUpdate(s.info, false) s.logger.TCPStreamPropUpdate(s.info, false)
// Match properties against ruleset // Match properties against ruleset
result := s.ruleset.Match(s.info) result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs := s.currentRuleset(); rs != nil {
result = rs.Match(s.info)
}
action := result.Action action := result.Action
if action != ruleset.ActionMaybe && action != ruleset.ActionModify { if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
verdict := actionToTCPVerdict(action) verdict := actionToTCPVerdict(action)
@@ -171,6 +178,13 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass
} }
} }
func (s *tcpStream) currentRuleset() ruleset.Ruleset {
if s.rulesetSource == nil {
return nil
}
return s.rulesetSource()
}
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
s.closeActiveEntries() s.closeActiveEntries()
return true return true

View File

@@ -60,9 +60,7 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.UDPStreamNew(f.WorkerID, info) f.Logger.UDPStreamNew(f.WorkerID, info)
f.RulesetMutex.RLock() rs := f.currentRuleset()
rs := f.Ruleset
f.RulesetMutex.RUnlock()
ans := analyzersToUDPAnalyzers(rs.Analyzers(info)) ans := analyzersToUDPAnalyzers(rs.Analyzers(info))
// Create entries for each analyzer // Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans)) entries := make([]*udpStreamEntry, 0, len(ans))
@@ -87,7 +85,7 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
info: info, info: info,
virgin: true, virgin: true,
logger: f.Logger, logger: f.Logger,
ruleset: rs, rulesetSource: f.currentRuleset,
activeEntries: entries, activeEntries: entries,
} }
} }
@@ -99,6 +97,12 @@ func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
return nil return nil
} }
func (f *udpStreamFactory) currentRuleset() ruleset.Ruleset {
f.RulesetMutex.RLock()
defer f.RulesetMutex.RUnlock()
return f.Ruleset
}
type udpStreamManager struct { type udpStreamManager struct {
factory *udpStreamFactory factory *udpStreamFactory
streams *lru.Cache[uint32, *udpStreamValue] streams *lru.Cache[uint32, *udpStreamValue]
@@ -186,7 +190,7 @@ type udpStream struct {
info ruleset.StreamInfo info ruleset.StreamInfo
virgin bool // true if no packets have been processed virgin bool // true if no packets have been processed
logger Logger logger Logger
ruleset ruleset.Ruleset rulesetSource func() ruleset.Ruleset
activeEntries []*udpStreamEntry activeEntries []*udpStreamEntry
doneEntries []*udpStreamEntry doneEntries []*udpStreamEntry
lastVerdict udpVerdict lastVerdict udpVerdict
@@ -229,7 +233,10 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
s.virgin = false s.virgin = false
s.logger.UDPStreamPropUpdate(s.info, false) s.logger.UDPStreamPropUpdate(s.info, false)
// Match properties against ruleset // Match properties against ruleset
result := s.ruleset.Match(s.info) result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs := s.currentRuleset(); rs != nil {
result = rs.Match(s.info)
}
action := result.Action action := result.Action
if action == ruleset.ActionModify { if action == ruleset.ActionModify {
// Call the modifier instance // Call the modifier instance
@@ -266,6 +273,13 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
} }
} }
func (s *udpStream) currentRuleset() ruleset.Ruleset {
if s.rulesetSource == nil {
return nil
}
return s.rulesetSource()
}
func (s *udpStream) Close() { func (s *udpStream) Close() {
s.closeActiveEntries() s.closeActiveEntries()
} }