diff --git a/app.go b/app.go index b1cff75..b9d36fc 100644 --- a/app.go +++ b/app.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "runtime" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/engine" @@ -17,6 +18,7 @@ type App struct { engine engine.Engine io gfwio.PacketIO rulesetConfig *ruleset.BuiltinConfig + ruleset ruleset.Ruleset analyzers []analyzer.Analyzer modifiers []modifier.Modifier rulesFile string @@ -42,6 +44,11 @@ func New(cfg Config, opts Options) (*App, error) { packetIO := cfg.IO.PacketIO ownsIO := false + workerCount := effectiveWorkerCount(cfg.Workers.Count) + numQueues := cfg.IO.NumQueues + if numQueues <= 0 { + numQueues = workerCount + } if packetIO == nil { packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{ QueueSize: cfg.IO.QueueSize, @@ -49,7 +56,7 @@ func New(cfg Config, opts Options) (*App, error) { WriteBuffer: cfg.IO.WriteBuffer, Local: cfg.IO.Local, RST: cfg.IO.RST, - NumQueues: cfg.IO.NumQueues, + NumQueues: numQueues, MaxPacketLen: cfg.IO.MaxPacketLen, }) if err != nil { @@ -79,11 +86,13 @@ func New(cfg Config, opts Options) (*App, error) { Logger: engineLogger, IO: packetIO, Ruleset: rs, - Workers: cfg.Workers.Count, + Workers: workerCount, WorkerQueueSize: cfg.Workers.QueueSize, WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal, WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn, WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams, + OverflowPolicy: cfg.Workers.OverflowPolicy, + AnalyzerSelectionMode: cfg.Workers.AnalyzerSelectionMode, } eng, err := engine.NewEngine(engCfg) if err != nil { @@ -95,6 +104,7 @@ func New(cfg Config, opts Options) (*App, error) { engine: eng, io: packetIO, rulesetConfig: rsConfig, + ruleset: rs, analyzers: analyzers, modifiers: modifiers, rulesFile: rulesFile, @@ -140,6 +150,17 @@ func (a *App) Engine() engine.Engine { return a.engine } +func effectiveWorkerCount(configured int) int { + if configured > 0 { + return 
configured + } + n := runtime.GOMAXPROCS(0) + if n <= 0 { + return 1 + } + return n +} + func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) { if opts.RulesFile != "" && len(opts.Rules) > 0 { return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")} diff --git a/config.go b/config.go index eaab254..cdbb048 100644 --- a/config.go +++ b/config.go @@ -32,11 +32,13 @@ type IOConfig struct { // WorkersConfig configures engine worker behavior. type WorkersConfig struct { - Count int `mapstructure:"count" yaml:"count"` - QueueSize int `mapstructure:"queueSize" yaml:"queueSize"` - TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"` - TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"` - UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"` + Count int `mapstructure:"count" yaml:"count"` + QueueSize int `mapstructure:"queueSize" yaml:"queueSize"` + TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"` + TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"` + UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"` + OverflowPolicy engine.OverflowPolicy `mapstructure:"overflowPolicy" yaml:"overflowPolicy"` + AnalyzerSelectionMode engine.AnalyzerSelectionMode `mapstructure:"analyzerSelectionMode" yaml:"analyzerSelectionMode"` } // RulesetConfig configures built-in rule helpers. 
diff --git a/engine/analyzer_selector.go b/engine/analyzer_selector.go new file mode 100644 index 0000000..e1aceb7 --- /dev/null +++ b/engine/analyzer_selector.go @@ -0,0 +1,235 @@ +package engine + +import ( + "bytes" + "strings" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +type analyzerSelector struct { + mode AnalyzerSelectionMode + stats *statsCounters +} + +func newAnalyzerSelector(mode AnalyzerSelectionMode, stats *statsCounters) *analyzerSelector { + if mode == "" { + mode = AnalyzerSelectionModeSignature + } + return &analyzerSelector{mode: mode, stats: stats} +} + +func (s *analyzerSelector) SelectTCP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer { + if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 { + return ans + } + allowed := tcpAllowedAnalyzers(payload) + if len(allowed) == 0 { + return ans + } + out := make([]analyzer.Analyzer, 0, len(ans)) + for _, a := range ans { + name := strings.ToLower(a.Name()) + if _, known := knownTCPAnalyzers[name]; !known { + out = append(out, a) + continue + } + if allowed[name] { + out = append(out, a) + } + } + s.recordSelection(len(ans), len(out)) + if len(out) == 0 { + return ans + } + return out +} + +func (s *analyzerSelector) SelectUDP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer { + if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 { + return ans + } + allowed := udpAllowedAnalyzers(payload) + if len(allowed) == 0 { + return ans + } + out := make([]analyzer.Analyzer, 0, len(ans)) + for _, a := range ans { + name := strings.ToLower(a.Name()) + if _, known := knownUDPAnalyzers[name]; !known { + out = append(out, a) + continue + } + if allowed[name] { + out = append(out, a) + } + } + s.recordSelection(len(ans), len(out)) + if len(out) == 0 { + return ans + } + return out +} + +func (s *analyzerSelector) recordSelection(total, selected int) { + if s == nil || s.stats == nil || total <= 0 { + return + } + 
s.stats.AnalyzerSelectionsTotal.Add(1) + if selected < total { + s.stats.AnalyzerSelectionsPruned.Add(1) + } +} + +var ( + knownTCPAnalyzers = map[string]struct{}{ + "fet": {}, + "http": {}, + "socks": {}, + "ssh": {}, + "tls": {}, + "trojan": {}, + "dns": {}, + "openvpn": {}, + } + knownUDPAnalyzers = map[string]struct{}{ + "dns": {}, + "openvpn": {}, + "quic": {}, + "wireguard": {}, + } +) + +func tcpAllowedAnalyzers(payload []byte) map[string]bool { + allowed := make(map[string]bool, 4) + if looksLikeTLS(payload) { + allowed["tls"] = true + allowed["trojan"] = true + allowed["fet"] = true + } + if looksLikeHTTP(payload) { + allowed["http"] = true + allowed["fet"] = true + } + if looksLikeSSH(payload) { + allowed["ssh"] = true + allowed["fet"] = true + } + if looksLikeSOCKS(payload) { + allowed["socks"] = true + allowed["fet"] = true + } + if looksLikeDNSTCP(payload) { + allowed["dns"] = true + allowed["fet"] = true + } + if len(allowed) == 0 { + return nil + } + return allowed +} + +func udpAllowedAnalyzers(payload []byte) map[string]bool { + allowed := make(map[string]bool, 4) + if looksLikeWireGuard(payload) { + allowed["wireguard"] = true + } + if looksLikeOpenVPN(payload) { + allowed["openvpn"] = true + } + if looksLikeQUIC(payload) { + allowed["quic"] = true + } + if looksLikeDNSUDP(payload) { + allowed["dns"] = true + } + if len(allowed) == 0 { + return nil + } + return allowed +} + +func looksLikeTLS(payload []byte) bool { + if len(payload) < 3 { + return false + } + return (payload[0] == 0x16 || payload[0] == 0x17) && payload[1] == 0x03 && payload[2] <= 0x09 +} + +func looksLikeHTTP(payload []byte) bool { + if len(payload) < 3 { + return false + } + head := strings.ToUpper(string(payload[:3])) + switch head { + case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": + return true + default: + return false + } +} + +func looksLikeSSH(payload []byte) bool { + return len(payload) >= 4 && bytes.HasPrefix(payload, []byte("SSH-")) +} + +func 
looksLikeSOCKS(payload []byte) bool { + if len(payload) < 2 { + return false + } + return payload[0] == 0x04 || payload[0] == 0x05 +} + +func looksLikeDNSTCP(payload []byte) bool { + if len(payload) < 14 { + return false + } + msgLen := int(payload[0])<<8 | int(payload[1]) + if msgLen <= 0 || msgLen+2 > len(payload) { + return false + } + qd := int(payload[6])<<8 | int(payload[7]) + an := int(payload[8])<<8 | int(payload[9]) + return qd+an > 0 +} + +func looksLikeDNSUDP(payload []byte) bool { + if len(payload) < 12 { + return false + } + qd := int(payload[4])<<8 | int(payload[5]) + an := int(payload[6])<<8 | int(payload[7]) + ns := int(payload[8])<<8 | int(payload[9]) + ar := int(payload[10])<<8 | int(payload[11]) + return qd+an+ns+ar > 0 +} + +func looksLikeQUIC(payload []byte) bool { + if len(payload) < 6 { + return false + } + // Long header with non-zero version. + if payload[0]&0x80 == 0 { + return false + } + version := uint32(payload[1])<<24 | uint32(payload[2])<<16 | uint32(payload[3])<<8 | uint32(payload[4]) + return version != 0 +} + +func looksLikeOpenVPN(payload []byte) bool { + if len(payload) == 0 { + return false + } + opcode := payload[0] >> 3 + return opcode >= 1 && opcode <= 11 +} + +func looksLikeWireGuard(payload []byte) bool { + if len(payload) < 4 { + return false + } + if payload[0] < 1 || payload[0] > 4 { + return false + } + return payload[1] == 0 && payload[2] == 0 && payload[3] == 0 +} diff --git a/engine/analyzer_selector_test.go b/engine/analyzer_selector_test.go new file mode 100644 index 0000000..c4acebb --- /dev/null +++ b/engine/analyzer_selector_test.go @@ -0,0 +1,56 @@ +package engine + +import ( + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +type namedAnalyzer struct{ name string } + +func (a namedAnalyzer) Name() string { return a.name } +func (a namedAnalyzer) Limit() int { return 0 } + +func TestSignatureSelectorTCPPrunesByPayloadNotPort(t *testing.T) { + sel := 
newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}) + all := []analyzer.Analyzer{ + namedAnalyzer{"http"}, + namedAnalyzer{"tls"}, + namedAnalyzer{"trojan"}, + namedAnalyzer{"ssh"}, + namedAnalyzer{"socks"}, + namedAnalyzer{"fet"}, + } + // TLS record-like prefix, regardless of destination port. + payload := []byte{0x16, 0x03, 0x03, 0x00, 0x10} + selected := sel.SelectTCP(all, payload) + + got := make(map[string]bool) + for _, a := range selected { + got[a.Name()] = true + } + for _, keep := range []string{"tls", "trojan", "fet"} { + if !got[keep] { + t.Fatalf("expected analyzer %q to be selected", keep) + } + } + for _, drop := range []string{"http", "ssh", "socks"} { + if got[drop] { + t.Fatalf("expected analyzer %q to be pruned", drop) + } + } +} + +func TestSignatureSelectorConservativeFallback(t *testing.T) { + sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}) + all := []analyzer.Analyzer{ + namedAnalyzer{"http"}, + namedAnalyzer{"tls"}, + namedAnalyzer{"custom"}, + } + payload := []byte{0xde, 0xad, 0xbe, 0xef} + selected := sel.SelectTCP(all, payload) + if len(selected) != len(all) { + t.Fatalf("expected conservative fallback to keep all analyzers, got=%d want=%d", len(selected), len(all)) + } +} diff --git a/engine/engine.go b/engine/engine.go index 269a511..9934090 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,6 +2,7 @@ package engine import ( "context" + "runtime" "sync" "sync/atomic" @@ -20,18 +21,34 @@ type engine struct { logger Logger io io.PacketIO workers []*worker - verdicts sync.Map // streamID(uint32) → verdictEntry + stats *statsCounters + verdicts sync.Map // streamID(uint32) -> verdictEntry verdictsGen atomic.Int64 // incremented on ruleset update - overflowCh chan *workerPacket - overflowOnce sync.Once + overflowPolicy OverflowPolicy + resultCh chan workerResult } func NewEngine(config Config) (Engine, error) { workerCount := config.Workers if workerCount <= 0 { - workerCount = 1 + 
workerCount = runtime.GOMAXPROCS(0) + if workerCount <= 0 { + workerCount = 1 + } } + overflowPolicy := config.OverflowPolicy + if overflowPolicy == "" { + overflowPolicy = OverflowPolicyAccept + } + selectionMode := config.AnalyzerSelectionMode + if selectionMode == "" { + selectionMode = AnalyzerSelectionModeSignature + } + + stats := &statsCounters{} + resultCh := make(chan workerResult, workerCount*256) + macResolver := newSourceMACResolver() var err error workers := make([]*worker, workerCount) @@ -45,16 +62,21 @@ func NewEngine(config Config) (Engine, error) { TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal, TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn, UDPMaxStreams: config.WorkerUDPMaxStreams, + AnalyzerSelectionMode: selectionMode, + ResultChan: resultCh, + Stats: stats, }) if err != nil { return nil, err } } e := &engine{ - logger: config.Logger, - io: config.IO, - workers: workers, - overflowCh: make(chan *workerPacket, 1024), + logger: config.Logger, + io: config.IO, + workers: workers, + stats: stats, + overflowPolicy: overflowPolicy, + resultCh: resultCh, } return e, nil } @@ -74,13 +96,10 @@ func (e *engine) Run(ctx context.Context) error { ioCtx, ioCancel := context.WithCancel(ctx) defer ioCancel() - e.overflowOnce.Do(func() { - go e.drainOverflow(ioCtx) - }) - for _, w := range e.workers { go w.Run(ioCtx) } + go e.drainResults(ioCtx) errChan := make(chan error, 1) err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { @@ -121,24 +140,35 @@ func (e *engine) dispatch(p io.Packet) bool { gen := e.verdictsGen.Load() index := streamID % uint32(len(e.workers)) wp := &workerPacket{ - StreamID: streamID, - Data: data, - SetVerdict: func(v io.Verdict, b []byte) error { - if v == io.VerdictAcceptStream || v == io.VerdictDropStream { - e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen}) - } - return e.io.SetVerdict(p, v, b) - }, + Packet: p, + StreamID: streamID, + Data: data, + Gen: gen, } if 
!e.workers[index].Feed(wp) { - select { - case e.overflowCh <- wp: + e.stats.OverflowEvents.Add(1) + switch e.overflowPolicy { + case OverflowPolicyDrop: + e.stats.OverflowDrops.Add(1) + _ = e.io.SetVerdict(p, io.VerdictDrop, nil) + case OverflowPolicyBackpressure: + e.stats.OverflowBackpressureEvents.Add(1) + e.workers[index].FeedBlocking(wp) default: + e.stats.OverflowAccepts.Add(1) + _ = e.io.SetVerdict(p, io.VerdictAccept, nil) } } return true } +func (e *engine) applyWorkerResult(r workerResult) { + if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream { + e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen}) + } + _ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket) +} + func validPacket(data []byte) bool { if len(data) == 0 { return false @@ -156,13 +186,26 @@ func validPacket(data []byte) bool { return false } -func (e *engine) drainOverflow(ctx context.Context) { +func (e *engine) drainResults(ctx context.Context) { for { select { case <-ctx.Done(): return - case wp := <-e.overflowCh: - _ = wp.SetVerdict(io.VerdictAccept, nil) + case r := <-e.resultCh: + e.applyWorkerResult(r) } } } + +func (e *engine) Stats() Stats { + return Stats{ + OverflowEvents: e.stats.OverflowEvents.Load(), + OverflowAccepts: e.stats.OverflowAccepts.Load(), + OverflowDrops: e.stats.OverflowDrops.Load(), + OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(), + AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(), + AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(), + UDPTupleLookups: e.stats.UDPTupleLookups.Load(), + UDPTupleHits: e.stats.UDPTupleHits.Load(), + } +} diff --git a/engine/interface.go b/engine/interface.go index d544dbf..4c544b6 100644 --- a/engine/interface.go +++ b/engine/interface.go @@ -2,6 +2,7 @@ package engine import ( "context" + "sync/atomic" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" @@ -13,6 +14,49 @@ type Engine interface { 
UpdateRuleset(ruleset.Ruleset) error // Run runs the engine, until an error occurs or the context is cancelled. Run(context.Context) error + // Stats returns a consistent snapshot of runtime counters. + Stats() Stats +} + +type OverflowPolicy string + +const ( + OverflowPolicyAccept OverflowPolicy = "accept" + OverflowPolicyDrop OverflowPolicy = "drop" + OverflowPolicyBackpressure OverflowPolicy = "backpressure" +) + +type AnalyzerSelectionMode string + +const ( + AnalyzerSelectionModeAlways AnalyzerSelectionMode = "always" + AnalyzerSelectionModeSignature AnalyzerSelectionMode = "signature" +) + +type statsCounters struct { + OverflowEvents atomic.Uint64 + OverflowAccepts atomic.Uint64 + OverflowDrops atomic.Uint64 + OverflowBackpressureEvents atomic.Uint64 + + AnalyzerSelectionsTotal atomic.Uint64 + AnalyzerSelectionsPruned atomic.Uint64 + + UDPTupleLookups atomic.Uint64 + UDPTupleHits atomic.Uint64 +} + +type Stats struct { + OverflowEvents uint64 + OverflowAccepts uint64 + OverflowDrops uint64 + OverflowBackpressureEvents uint64 + + AnalyzerSelectionsTotal uint64 + AnalyzerSelectionsPruned uint64 + + UDPTupleLookups uint64 + UDPTupleHits uint64 } // Config is the configuration for the engine. @@ -26,6 +70,8 @@ type Config struct { WorkerTCPMaxBufferedPagesTotal int WorkerTCPMaxBufferedPagesPerConn int WorkerUDPMaxStreams int + OverflowPolicy OverflowPolicy + AnalyzerSelectionMode AnalyzerSelectionMode } // Logger is the combined logging interface for the engine, workers and analyzers. 
diff --git a/engine/mac_resolver.go b/engine/mac_resolver.go index 91fefd7..30c1168 100644 --- a/engine/mac_resolver.go +++ b/engine/mac_resolver.go @@ -1,3 +1,6 @@ +//go:build linux +// +build linux + package engine import ( diff --git a/engine/mac_resolver_stub.go b/engine/mac_resolver_stub.go new file mode 100644 index 0000000..4cf4e55 --- /dev/null +++ b/engine/mac_resolver_stub.go @@ -0,0 +1,17 @@ +//go:build !linux +// +build !linux + +package engine + +import "net" + +type sourceMACResolver struct{} + +func newSourceMACResolver() *sourceMACResolver { + return &sourceMACResolver{} +} + +func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr { + _ = ip + return nil +} diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go index 04ae7c4..cda881a 100644 --- a/engine/reload_rules_test.go +++ b/engine/reload_rules_test.go @@ -142,7 +142,7 @@ func TestTCPFlowUsesUpdatedRuleset(t *testing.T) { if err != nil { t.Fatalf("create node: %v", err) } - mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil) mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) l3 := L3Info{ @@ -180,7 +180,7 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) { if err != nil { t.Fatalf("create node: %v", err) } - mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil) mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) l3 := L3Info{ diff --git a/engine/tcp_flow.go b/engine/tcp_flow.go index d1c5cca..33ba874 100644 --- a/engine/tcp_flow.go +++ b/engine/tcp_flow.go @@ -163,15 +163,17 @@ type tcpFlowManager struct { rulesetSource func() (ruleset.Ruleset, uint64) workerID int macResolver *sourceMACResolver + selector *analyzerSelector } -func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager { +func newTCPFlowManager(workerID int, logger 
Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager { return &tcpFlowManager{ flows: make(map[uint32]*tcpFlow), sfNode: node, logger: logger, workerID: workerID, macResolver: macResolver, + selector: selector, } } @@ -179,7 +181,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload m.mu.Lock() flow, ok := m.flows[streamID] if !ok { - flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC) + flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC) m.flows[streamID] = flow } m.mu.Unlock() @@ -195,7 +197,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload return verdict } -func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { +func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { id := m.sfNode.Generate() ipSrc := net.IP(l3.SrcIP[:]) ipDst := net.IP(l3.DstIP[:]) @@ -217,7 +219,9 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, src rs, version := m.rulesetSource() var ans []analyzer.TCPAnalyzer if rs != nil { - ans = analyzersToTCPAnalyzers(rs.Analyzers(info)) + baseAns := rs.Analyzers(info) + baseAns = m.selector.SelectTCP(baseAns, payload) + ans = analyzersToTCPAnalyzers(baseAns) } entries := make([]*tcpFlowEntry, 0, len(ans)) for _, a := range ans { diff --git a/engine/udp.go b/engine/udp.go index 7668ebf..5d43177 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -1,6 +1,7 @@ package engine import ( + "bytes" "errors" "net" "sync" @@ -40,6 +41,8 @@ type udpStreamFactory struct { WorkerID int Logger Logger Node *snowflake.Node + Selector *analyzerSelector + Stats *statsCounters RulesetMutex sync.RWMutex Ruleset ruleset.Ruleset @@ -64,7 +67,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u rs, version := f.currentRuleset() var ans 
[]analyzer.UDPAnalyzer if rs != nil { - ans = analyzersToUDPAnalyzers(rs.Analyzers(info)) + baseAns := rs.Analyzers(info) + if f.Selector != nil { + baseAns = f.Selector.SelectUDP(baseAns, udp.Payload) + } + ans = analyzersToUDPAnalyzers(baseAns) } // Create entries for each analyzer entries := make([]*udpStreamEntry, 0, len(ans)) @@ -110,8 +117,11 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { } type udpStreamManager struct { - factory *udpStreamFactory - streams *lru.Cache[uint32, *udpStreamValue] + factory *udpStreamFactory + streams *lru.Cache[uint32, *udpStreamValue] + tupleIndex map[udpTupleKey]uint32 + streamTuples map[uint32]udpTupleKey + stats *statsCounters } type udpStreamValue struct { @@ -120,36 +130,71 @@ type udpStreamValue struct { UDPFlow gopacket.Flow } +type udpTupleKey struct { + AIP [16]byte + BIP [16]byte + ALen uint8 + BLen uint8 + APort uint16 + BPort uint16 +} + func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) { fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse() return fwd || rev, rev } -func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) { - ss, err := lru.New[uint32, *udpStreamValue](maxStreams) +func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { + m := &udpStreamManager{ + factory: factory, + tupleIndex: make(map[udpTupleKey]uint32, maxStreams), + streamTuples: make(map[uint32]udpTupleKey, maxStreams), + stats: stats, + } + ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) { + m.removeTupleMappingLocked(k) + }) if err != nil { return nil, err } - return &udpStreamManager{ - factory: factory, - streams: ss, - }, nil + m.streams = ss + return m, nil } func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc 
*udpContext) { rev := false value, ok := m.streams.Get(streamID) + tuple := canonicalUDPTupleKey(ipFlow, udp) if !ok { - // Fallback: conntrack IDs can change during early flow lifetime on some systems. - // Try to find an existing stream by 5-tuple before creating a new stream. - matchedKey, matchedValue, matchedRev, found := m.findByFlow(ipFlow, udp.TransportFlow()) + if m.stats != nil { + m.stats.UDPTupleLookups.Add(1) + } + // Conntrack IDs can change during early flow lifetime on some systems. + // Rebind by canonical 5-tuple in O(1). + matchedKey, found := m.tupleIndex[tuple] + var matchedValue *udpStreamValue + var matchedRev bool if found { + if m.stats != nil { + m.stats.UDPTupleHits.Add(1) + } + var hasValue bool + matchedValue, hasValue = m.streams.Get(matchedKey) + if !hasValue || matchedValue == nil { + delete(m.tupleIndex, tuple) + delete(m.streamTuples, matchedKey) + found = false + } + } + if found { + _, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow()) value = matchedValue rev = matchedRev if matchedKey != streamID { m.streams.Remove(matchedKey) m.streams.Add(streamID, matchedValue) + m.bindTupleLocked(streamID, tuple) } } else { // New stream @@ -159,6 +204,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo UDPFlow: udp.TransportFlow(), } m.streams.Add(streamID, value) + m.bindTupleLocked(streamID, tuple) } } else { // Stream ID exists, but is it really the same stream? 
@@ -172,6 +218,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo UDPFlow: udp.TransportFlow(), } m.streams.Add(streamID, value) + m.bindTupleLocked(streamID, tuple) } } if value.Stream.Accept(udp, rev, uc) { @@ -179,17 +226,58 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo } } -func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32, value *udpStreamValue, rev bool, found bool) { - for _, k := range m.streams.Keys() { - v, ok := m.streams.Peek(k) - if !ok || v == nil { - continue - } - if ok2, rev2 := v.Match(ipFlow, udpFlow); ok2 { - return k, v, rev2, true +func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) { + m.removeTupleMappingLocked(streamID) + m.tupleIndex[key] = streamID + m.streamTuples[streamID] = key +} + +func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) { + if key, ok := m.streamTuples[streamID]; ok { + delete(m.streamTuples, streamID) + current, exists := m.tupleIndex[key] + if exists && current == streamID { + delete(m.tupleIndex, key) } } - return 0, nil, false, false +} + +func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey { + srcIP := ipFlow.Src().Raw() + dstIP := ipFlow.Dst().Raw() + srcPort := uint16(udp.SrcPort) + dstPort := uint16(udp.DstPort) + + if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 { + srcIP, dstIP = dstIP, srcIP + srcPort, dstPort = dstPort, srcPort + } + + var key udpTupleKey + key.ALen = uint8(copy(key.AIP[:], srcIP)) + key.BLen = uint8(copy(key.BIP[:], dstIP)) + key.APort = srcPort + key.BPort = dstPort + return key +} + +func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int { + if len(aIP) != len(bIP) { + if len(aIP) < len(bIP) { + return -1 + } + return 1 + } + if c := bytes.Compare(aIP, bIP); c != 0 { + return c + } + if aPort < bPort { + return -1 + } + if aPort > bPort { + return 1 + } + return 0 } type udpStream struct { 
diff --git a/engine/udp_manager_bench_test.go b/engine/udp_manager_bench_test.go new file mode 100644 index 0000000..49a8fb9 --- /dev/null +++ b/engine/udp_manager_bench_test.go @@ -0,0 +1,122 @@ +package engine + +import ( + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/ruleset" + + "github.com/bwmarrin/snowflake" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +type legacyUDPStreamValue struct { + IPFlow gopacket.Flow + UDPFlow gopacket.Flow +} + +type emptyRuleset struct{} + +func (emptyRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil } +func (emptyRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult { + return ruleset.MatchResult{Action: ruleset.ActionMaybe} +} + +func benchmarkUDPManager(b *testing.B, churn bool) { + node, err := snowflake.NewNode(0) + if err != nil { + b.Fatalf("create node: %v", err) + } + factory := &udpStreamFactory{WorkerID: 0, Logger: noopTestLogger{}, Node: node, Ruleset: emptyRuleset{}} + mgr, err := newUDPStreamManager(factory, 200000, &statsCounters{}) + if err != nil { + b.Fatalf("new manager: %v", err) + } + + const flowCount = 20000 + flows := make([]gopacket.Flow, flowCount) + udps := make([]*layers.UDP, flowCount) + for i := 0; i < flowCount; i++ { + a := byte(i >> 8) + c := byte(i) + flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) + udps[i] = &layers.UDP{ + SrcPort: layers.UDPPort(1024 + i%20000), + DstPort: layers.UDPPort(20000 + (i*7)%20000), + BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, + } + } + + ctx := &udpContext{Verdict: udpVerdictAccept} + b.ResetTimer() + for i := 0; i < b.N; i++ { + idx := i % flowCount + streamID := uint32(idx + 1) + if churn { + streamID = uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount)) + } + ctx.Verdict = udpVerdictAccept + ctx.Packet = nil + mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx) 
+ } +} + +func BenchmarkUDPManagerMatchStableStreamID(b *testing.B) { + benchmarkUDPManager(b, false) +} + +func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) { + benchmarkUDPManager(b, true) +} + +func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) { + const flowCount = 5000 + flows := make([]gopacket.Flow, flowCount) + udps := make([]*layers.UDP, flowCount) + for i := 0; i < flowCount; i++ { + a := byte(i >> 8) + c := byte(i) + flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) + udps[i] = &layers.UDP{ + SrcPort: layers.UDPPort(1024 + i%20000), + DstPort: layers.UDPPort(20000 + (i*7)%20000), + BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, + } + } + + streams := make(map[uint32]*legacyUDPStreamValue, flowCount) + keys := make([]uint32, 0, flowCount) + for i := 0; i < flowCount; i++ { + streamID := uint32(i + 1) + streams[streamID] = &legacyUDPStreamValue{ + IPFlow: flows[i], + UDPFlow: udps[i].TransportFlow(), + } + keys = append(keys, streamID) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + idx := i % flowCount + streamID := uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount)) + if _, ok := streams[streamID]; ok { + continue + } + ipFlow := flows[idx] + udpFlow := udps[idx].TransportFlow() + for _, k := range keys { + v, ok := streams[k] + if !ok || v == nil { + continue + } + if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) || + (v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) { + delete(streams, k) + streams[streamID] = v + break + } + } + } +} diff --git a/engine/udp_manager_tuple_test.go b/engine/udp_manager_tuple_test.go new file mode 100644 index 0000000..3d9e5f2 --- /dev/null +++ b/engine/udp_manager_tuple_test.go @@ -0,0 +1,71 @@ +package engine + +import ( + "net" + "sync/atomic" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/ruleset" + + "github.com/bwmarrin/snowflake" + 
"github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +type countingRuleset struct { + ans []analyzer.Analyzer +} + +func (r countingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return r.ans } +func (r countingRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult { + return ruleset.MatchResult{Action: ruleset.ActionMaybe} +} + +type countingUDPAnalyzer struct{ newCalls *atomic.Uint64 } + +func (a countingUDPAnalyzer) Name() string { return "countudp" } +func (a countingUDPAnalyzer) Limit() int { return 0 } +func (a countingUDPAnalyzer) NewUDP(analyzer.UDPInfo, analyzer.Logger) analyzer.UDPStream { + a.newCalls.Add(1) + return countingUDPStream{} +} + +type countingUDPStream struct{} + +func (countingUDPStream) Feed(bool, []byte) (*analyzer.PropUpdate, bool) { return nil, false } +func (countingUDPStream) Close(bool) *analyzer.PropUpdate { return nil } + +func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + var newCalls atomic.Uint64 + rs := countingRuleset{ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &newCalls}}} + factory := &udpStreamFactory{ + WorkerID: 0, + Logger: noopTestLogger{}, + Node: node, + Ruleset: rs, + } + mgr, err := newUDPStreamManager(factory, 64, &statsCounters{}) + if err != nil { + t.Fatalf("new manager: %v", err) + } + + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) + udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}} + + ctx1 := &udpContext{Verdict: udpVerdictAccept} + mgr.MatchWithContext(100, ipFlow, udp, ctx1) + if got := newCalls.Load(); got != 1 { + t.Fatalf("new stream calls=%d want=1", got) + } + + ctx2 := &udpContext{Verdict: udpVerdictAccept} + mgr.MatchWithContext(200, ipFlow, udp, ctx2) + if got := newCalls.Load(); got != 1 { + t.Fatalf("expected stream 
reuse by tuple, new stream calls=%d want=1", got) + } +} diff --git a/engine/worker.go b/engine/worker.go index 96fbe86..684acbe 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -12,24 +12,32 @@ import ( "github.com/google/gopacket/layers" ) -var _ Engine = (*engine)(nil) - type workerPacket struct { - StreamID uint32 - Data []byte - SrcMAC net.HardwareAddr - DstMAC net.HardwareAddr - SetVerdict func(io.Verdict, []byte) error + Packet io.Packet + StreamID uint32 + Data []byte + SrcMAC net.HardwareAddr + DstMAC net.HardwareAddr + Gen int64 +} + +type workerResult struct { + Packet io.Packet + StreamID uint32 + Verdict io.Verdict + ModifiedPacket []byte + Gen int64 } type worker struct { id int packetChan chan *workerPacket + resultChan chan workerResult logger Logger macResolver *sourceMACResolver - tcpFlowMgr *tcpFlowManager - udpSM *udpStreamManager + tcpFlowMgr *tcpFlowManager + udpSM *udpStreamManager modSerializeBuffer gopacket.SerializeBuffer } @@ -43,6 +51,9 @@ type workerConfig struct { TCPMaxBufferedPagesTotal int // unused, kept for config compat TCPMaxBufferedPagesPerConn int // unused, kept for config compat UDPMaxStreams int + AnalyzerSelectionMode AnalyzerSelectionMode + ResultChan chan workerResult + Stats *statsCounters } func (c *workerConfig) fillDefaults() { @@ -61,7 +72,8 @@ func newWorker(config workerConfig) (*worker, error) { return nil, err } - tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode) + selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats) + tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector) if config.Ruleset != nil { tcpMgr.updateRuleset(config.Ruleset, 0) } @@ -71,8 +83,10 @@ func newWorker(config workerConfig) (*worker, error) { Logger: config.Logger, Node: sfNode, Ruleset: config.Ruleset, + Selector: selector, + Stats: config.Stats, } - udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams) + udpSM, err := 
newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats) if err != nil { return nil, err } @@ -80,6 +94,7 @@ func newWorker(config workerConfig) (*worker, error) { return &worker{ id: config.ID, packetChan: make(chan *workerPacket, config.ChanSize), + resultChan: config.ResultChan, logger: config.Logger, macResolver: config.MACResolver, tcpFlowMgr: tcpMgr, @@ -97,6 +112,10 @@ func (w *worker) Feed(p *workerPacket) bool { } } +func (w *worker) FeedBlocking(p *workerPacket) { + w.packetChan <- p +} + func (w *worker) Run(ctx context.Context) { w.logger.WorkerStart(w.id) defer w.logger.WorkerStop(w.id) @@ -109,7 +128,13 @@ func (w *worker) Run(ctx context.Context) { return } v, b := w.handle(wp) - _ = wp.SetVerdict(v, b) + w.resultChan <- workerResult{ + Packet: wp.Packet, + StreamID: wp.StreamID, + Verdict: v, + ModifiedPacket: b, + Gen: wp.Gen, + } } } } @@ -185,8 +210,7 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by SrcMAC: srcMAC, DstMAC: dstMAC, } - // Temporarily set payload on a UDP layer so existing UDP handling works - // We pass the payload through the context + // Temporarily set payload on a UDP layer so existing UDP handling works. 
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ BaseLayer: layers.BaseLayer{Payload: payload}, SrcPort: layers.UDPPort(udp.SrcPort), diff --git a/io/nfqueue.go b/io/nfqueue.go index e880a4a..40e3ae8 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -1,3 +1,6 @@ +//go:build linux +// +build linux + package io import ( diff --git a/io/nfqueue_stub.go b/io/nfqueue_stub.go new file mode 100644 index 0000000..3f5ac54 --- /dev/null +++ b/io/nfqueue_stub.go @@ -0,0 +1,43 @@ +//go:build !linux +// +build !linux + +package io + +import ( + "context" + "errors" + "net" +) + +var errNFQueueUnsupported = errors.New("nfqueue packet io is only supported on linux") + +type NFQueuePacketIOConfig struct { + QueueSize uint32 + ReadBuffer int + WriteBuffer int + Local bool + RST bool + NumQueues int + MaxPacketLen uint32 +} + +func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { + _ = config + return nil, errNFQueueUnsupported +} + +func (*unsupportedPacketIO) Register(context.Context, PacketCallback) error { + return errNFQueueUnsupported +} + +func (*unsupportedPacketIO) SetVerdict(Packet, Verdict, []byte) error { + return errNFQueueUnsupported +} + +func (*unsupportedPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) { + return nil, errNFQueueUnsupported +} + +func (*unsupportedPacketIO) Close() error { return nil } + +type unsupportedPacketIO struct{} diff --git a/ruleset/builtins/geo/geo_matcher.go b/ruleset/builtins/geo/geo_matcher.go index 1bb0f30..7f07ec7 100644 --- a/ruleset/builtins/geo/geo_matcher.go +++ b/ruleset/builtins/geo/geo_matcher.go @@ -1,52 +1,45 @@ package geo import ( + "container/list" "net" + "sort" "strings" "sync" ) +const ( + geoSiteResultCacheSize = 1 << 16 + geoSiteSetResultCacheSize = 1 << 16 +) + type GeoMatcher struct { geoLoader GeoLoader geoSiteMatcher map[string]hostMatcher - siteMatcherLock sync.Mutex + siteMatcherLock sync.RWMutex + geoSiteSets map[string][]hostMatcher + siteSetLock 
sync.RWMutex geoIpMatcher map[string]hostMatcher - ipMatcherLock sync.Mutex + ipMatcherLock sync.RWMutex + geoSiteResult *boolLRUCache + geoSiteSetCache *boolLRUCache } func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher { return &GeoMatcher{ - geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename), - geoSiteMatcher: make(map[string]hostMatcher), - geoIpMatcher: make(map[string]hostMatcher), + geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename), + geoSiteMatcher: make(map[string]hostMatcher), + geoSiteSets: make(map[string][]hostMatcher), + geoIpMatcher: make(map[string]hostMatcher), + geoSiteResult: newBoolLRUCache(geoSiteResultCacheSize), + geoSiteSetCache: newBoolLRUCache(geoSiteSetResultCacheSize), } } func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { - g.ipMatcherLock.Lock() - defer g.ipMatcherLock.Unlock() - - matcher, ok := g.geoIpMatcher[condition] - if !ok { - // GeoIP matcher - condition = strings.ToLower(condition) - country := condition - if len(country) == 0 { - return false - } - gMap, err := g.geoLoader.LoadGeoIP() - if err != nil { - return false - } - list, ok := gMap[country] - if !ok || list == nil { - return false - } - matcher, err = newGeoIPMatcher(list) - if err != nil { - return false - } - g.geoIpMatcher[condition] = matcher + matcher, ok := g.getOrCreateGeoIPMatcher(condition) + if !ok || matcher == nil { + return false } parseIp := net.ParseIP(ip) if parseIp == nil { @@ -64,32 +57,69 @@ func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { } func (g *GeoMatcher) MatchGeoSite(site, condition string) bool { - g.siteMatcherLock.Lock() - defer g.siteMatcherLock.Unlock() - - matcher, ok := g.geoSiteMatcher[condition] - if !ok { - // MatchGeoSite matcher - condition = strings.ToLower(condition) - name, attrs := parseGeoSiteName(condition) - if len(name) == 0 { - return false - } - gMap, err := g.geoLoader.LoadGeoSite() - if err != nil { - return false - } - list, ok := gMap[name] - if !ok 
|| list == nil {
-		return false
-	}
-	matcher, err = newGeositeMatcher(list, attrs)
-	if err != nil {
-		return false
-	}
-	g.geoSiteMatcher[condition] = matcher
+	conditionKey := strings.TrimSpace(strings.ToLower(condition))
+	if conditionKey == "" {
+		return false
 	}
-	return matcher.Match(HostInfo{Name: site})
+	cacheKey := site + "\x1f" + conditionKey
+	if v, ok := g.geoSiteResult.Get(cacheKey); ok {
+		return v
+	}
+	matcher, ok := g.getOrCreateGeoSiteMatcher(condition)
+	if !ok || matcher == nil {
+		return false
+	}
+	result := matcher.Match(HostInfo{Name: site})
+	g.geoSiteResult.Set(cacheKey, result)
+	return result
+}
+
+func (g *GeoMatcher) MatchGeoSiteSet(site string, set *SiteConditionSet) bool {
+	if set == nil {
+		return false
+	}
+	conditions := normalizeGeoSiteSetConditions(set.Conditions)
+	if len(conditions) == 0 {
+		return false
+	}
+	key := strings.Join(conditions, "\x1f")
+	cacheKey := site + "\x1e" + key
+	if v, ok := g.geoSiteSetCache.Get(cacheKey); ok {
+		return v
+	}
+
+	g.siteSetLock.RLock()
+	matchers, ok := g.geoSiteSets[key]
+	g.siteSetLock.RUnlock()
+	if !ok {
+		compiled := make([]hostMatcher, 0, len(conditions))
+		for _, condition := range conditions {
+			m, ok := g.getOrCreateGeoSiteMatcher(condition)
+			if ok && m != nil {
+				compiled = append(compiled, m)
+			}
+		}
+		g.siteSetLock.Lock()
+		if existing, exists := g.geoSiteSets[key]; exists {
+			matchers = existing
+		} else {
+			if len(compiled) == len(conditions) { g.geoSiteSets[key] = compiled } // cache only fully-compiled sets so transient load failures don't poison the cache
+			matchers = compiled
+		}
+		g.siteSetLock.Unlock()
+	}
+	if len(matchers) == 0 {
+		return false
+	}
+	host := HostInfo{Name: site}
+	for _, matcher := range matchers {
+		if matcher.Match(host) {
+			g.geoSiteSetCache.Set(cacheKey, true)
+			return true
+		}
+	}
+	g.geoSiteSetCache.Set(cacheKey, false)
+	return false
 }
 
 func (g *GeoMatcher) LoadGeoSite() error {
@@ -111,3 +141,152 @@ func parseGeoSiteName(s string) (string, []string) {
 	}
 	return base, attrs
 }
+
+func (g *GeoMatcher) getOrCreateGeoSiteMatcher(condition string) (hostMatcher, bool) {
+ condition = strings.TrimSpace(strings.ToLower(condition)) + if condition == "" { + return nil, false + } + g.siteMatcherLock.RLock() + matcher, ok := g.geoSiteMatcher[condition] + g.siteMatcherLock.RUnlock() + if ok { + return matcher, true + } + + name, attrs := parseGeoSiteName(condition) + if len(name) == 0 { + return nil, false + } + gMap, err := g.geoLoader.LoadGeoSite() + if err != nil { + return nil, false + } + list, ok := gMap[name] + if !ok || list == nil { + return nil, false + } + matcher, err = newGeositeMatcher(list, attrs) + if err != nil { + return nil, false + } + g.siteMatcherLock.Lock() + if existing, exists := g.geoSiteMatcher[condition]; exists { + matcher = existing + } else { + g.geoSiteMatcher[condition] = matcher + } + g.siteMatcherLock.Unlock() + return matcher, true +} + +func (g *GeoMatcher) getOrCreateGeoIPMatcher(condition string) (hostMatcher, bool) { + condition = strings.TrimSpace(strings.ToLower(condition)) + if condition == "" { + return nil, false + } + g.ipMatcherLock.RLock() + matcher, ok := g.geoIpMatcher[condition] + g.ipMatcherLock.RUnlock() + if ok { + return matcher, true + } + + gMap, err := g.geoLoader.LoadGeoIP() + if err != nil { + return nil, false + } + list, ok := gMap[condition] + if !ok || list == nil { + return nil, false + } + matcher, err = newGeoIPMatcher(list) + if err != nil { + return nil, false + } + g.ipMatcherLock.Lock() + if existing, exists := g.geoIpMatcher[condition]; exists { + matcher = existing + } else { + g.geoIpMatcher[condition] = matcher + } + g.ipMatcherLock.Unlock() + return matcher, true +} + +func normalizeGeoSiteSetConditions(in []string) []string { + if len(in) == 0 { + return nil + } + out := make([]string, 0, len(in)) + seen := make(map[string]struct{}, len(in)) + for _, v := range in { + s := strings.TrimSpace(strings.ToLower(v)) + if s == "" { + continue + } + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + sort.Strings(out) + return 
out +} + +type boolLRUCache struct { + mu sync.Mutex + cap int + ll *list.List + items map[string]*list.Element +} + +type boolCacheEntry struct { + key string + value bool +} + +func newBoolLRUCache(capacity int) *boolLRUCache { + if capacity <= 0 { + capacity = 1 + } + return &boolLRUCache{ + cap: capacity, + ll: list.New(), + items: make(map[string]*list.Element, capacity), + } +} + +func (c *boolLRUCache) Get(key string) (bool, bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ele, ok := c.items[key]; ok { + c.ll.MoveToFront(ele) + entry := ele.Value.(boolCacheEntry) + return entry.value, true + } + return false, false +} + +func (c *boolLRUCache) Set(key string, value bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ele, ok := c.items[key]; ok { + ele.Value = boolCacheEntry{key: key, value: value} + c.ll.MoveToFront(ele) + return + } + ele := c.ll.PushFront(boolCacheEntry{key: key, value: value}) + c.items[key] = ele + if c.ll.Len() <= c.cap { + return + } + back := c.ll.Back() + if back == nil { + return + } + entry := back.Value.(boolCacheEntry) + delete(c.items, entry.key) + c.ll.Remove(back) +} diff --git a/ruleset/builtins/geo/geo_matcher_test.go b/ruleset/builtins/geo/geo_matcher_test.go index a4bf580..5e3e860 100644 --- a/ruleset/builtins/geo/geo_matcher_test.go +++ b/ruleset/builtins/geo/geo_matcher_test.go @@ -1,13 +1,14 @@ package geo import ( + "sync/atomic" "testing" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" ) type fakeGeoLoader struct { - geoip map[string]*v2geo.GeoIP + geoip map[string]*v2geo.GeoIP geosite map[string]*v2geo.GeoSite } @@ -110,6 +111,83 @@ func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) { } } +func TestGeoMatcher_MatchGeoSiteSet(t *testing.T) { + loader := &fakeGeoLoader{ + geosite: map[string]*v2geo.GeoSite{ + "openai": { + Domain: []*v2geo.Domain{ + {Type: v2geo.Domain_Plain, Value: "openai"}, + }, + }, + "google": { + Domain: []*v2geo.Domain{ + {Type: v2geo.Domain_RootDomain, Value: "google.com"}, + }, 
+ }, + }, + } + g := NewGeoMatcher("", "") + g.geoLoader = loader + + set := &SiteConditionSet{Conditions: []string{" google ", "openai", "OPENAI"}} + if !g.MatchGeoSiteSet("api.openai.com", set) { + t.Error("MatchGeoSiteSet should match openai") + } + if !g.MatchGeoSiteSet("mail.google.com", set) { + t.Error("MatchGeoSiteSet should match google") + } + if g.MatchGeoSiteSet("example.com", set) { + t.Error("MatchGeoSiteSet should not match unrelated host") + } +} + +type countingMatcher struct { + calls *atomic.Uint64 + match bool +} + +func (m countingMatcher) Match(host HostInfo) bool { + _ = host + m.calls.Add(1) + return m.match +} + +func TestGeoMatcher_MatchGeoSite_UsesResultCache(t *testing.T) { + g := NewGeoMatcher("", "") + var calls atomic.Uint64 + g.geoSiteMatcher["openai"] = countingMatcher{calls: &calls, match: true} + + if !g.MatchGeoSite("api.openai.com", "openai") { + t.Fatal("expected match") + } + if !g.MatchGeoSite("api.openai.com", "openai") { + t.Fatal("expected cached match") + } + if got := calls.Load(); got != 1 { + t.Fatalf("matcher calls=%d want=1", got) + } +} + +func TestGeoMatcher_MatchGeoSiteSet_UsesResultCache(t *testing.T) { + g := NewGeoMatcher("", "") + var calls atomic.Uint64 + g.geoSiteSets["openai\x1fyoutube"] = []hostMatcher{ + countingMatcher{calls: &calls, match: false}, + countingMatcher{calls: &calls, match: true}, + } + + set := &SiteConditionSet{Conditions: []string{"youtube", "openai"}} + if !g.MatchGeoSiteSet("www.youtube.com", set) { + t.Fatal("expected match") + } + if !g.MatchGeoSiteSet("www.youtube.com", set) { + t.Fatal("expected cached match") + } + if got := calls.Load(); got != 2 { + t.Fatalf("matcher calls=%d want=2", got) + } +} + func ipv4(a, b, c, d byte) []byte { return []byte{a, b, c, d} } diff --git a/ruleset/builtins/geo/interface.go b/ruleset/builtins/geo/interface.go index 81c8bdc..79bb094 100644 --- a/ruleset/builtins/geo/interface.go +++ b/ruleset/builtins/geo/interface.go @@ -13,6 +13,10 @@ type 
HostInfo struct { IPv6 net.IP } +type SiteConditionSet struct { + Conditions []string +} + func (h HostInfo) String() string { return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6) } diff --git a/ruleset/expr.go b/ruleset/expr.go index 5c76c3d..a1bc2b9 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -60,8 +60,8 @@ type compiledExprRule struct { ModInstance modifier.Instance Program *vm.Program GeoSiteConditions []string - StartTimeSecs int // seconds since midnight, -1 if unset - StopTimeSecs int // seconds since midnight, -1 if unset + StartTimeSecs int // seconds since midnight, -1 if unset + StopTimeSecs int // seconds since midnight, -1 if unset Weekdays []time.Weekday WeekdaysNegated bool } @@ -86,6 +86,7 @@ type exprRuleset struct { Ans []analyzer.Analyzer Logger Logger GeoMatcher *geo.GeoMatcher + stats *statsCounters } func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { @@ -93,9 +94,24 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { } func (r *exprRuleset) Match(info StreamInfo) MatchResult { + start := time.Now() + if r.stats != nil { + r.stats.MatchCalls.Add(1) + defer func() { + r.stats.MatchLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) + }() + } + env := envPool.Get().(map[string]any) clear(env) - populateExprEnv(env, info) + macMap, ipMap, portMap := populateExprEnv(env, info) + releaseEnv := func() { + clear(env) + envPool.Put(env) + putSubMap(macMap) + putSubMap(ipMap) + putSubMap(portMap) + } now := time.Now() for _, rule := range r.Rules { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { @@ -103,6 +119,9 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { } v, err := vm.Run(rule.Program, env) if err != nil { + if r.stats != nil { + r.stats.MatchErrors.Add(1) + } r.Logger.MatchError(info, rule.Name, err) continue } @@ -115,7 +134,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { r.Logger.Log(logInfo, rule.Name) } if 
rule.Action != nil { - envPool.Put(env) + releaseEnv() return MatchResult{ Action: *rule.Action, ModInstance: rule.ModInstance, @@ -123,12 +142,26 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { } } } - envPool.Put(env) + releaseEnv() return MatchResult{ Action: ActionMaybe, } } +func (r *exprRuleset) Stats() Stats { + if r == nil || r.stats == nil { + return Stats{} + } + return Stats{ + MatchCalls: r.stats.MatchCalls.Load(), + MatchErrors: r.stats.MatchErrors.Load(), + MatchLatencyNanos: r.stats.MatchLatencyNanos.Load(), + LookupCalls: r.stats.LookupCalls.Load(), + LookupErrors: r.stats.LookupErrors.Load(), + LookupLatencyNanos: r.stats.LookupLatencyNanos.Load(), + } +} + // CompileExprRules compiles a list of expression rules into a ruleset. // It returns an error if any of the rules are invalid, or if any of the analyzers // used by the rules are unknown (not provided in the analyzer list). @@ -137,7 +170,8 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier fullAnMap := analyzersToMap(ans) fullModMap := modifiersToMap(mods) depAnMap := make(map[string]analyzer.Analyzer) - funcMap, geoMatcher := buildFunctionMap(config) + stats := &statsCounters{} + funcMap, geoMatcher := buildFunctionMap(config, stats) // Compile all rules and build a map of analyzers that are used by the rules. 
for _, rule := range rules { if rule.Action == "" && !rule.Log { @@ -152,7 +186,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier action = &a } visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} - patcher := &idPatcher{FuncMap: funcMap} + patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: geoMatcher} program, err := expr.Compile(rule.Expr, func(c *conf.Config) { c.Strict = false @@ -242,29 +276,47 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier Ans: depAns, Logger: config.Logger, GeoMatcher: geoMatcher, + stats: stats, }, nil } -func populateExprEnv(m map[string]any, info StreamInfo) { +func populateExprEnv(m map[string]any, info StreamInfo) (macMap, ipMap, portMap map[string]any) { + macMap = getSubMap() + ipMap = getSubMap() + portMap = getSubMap() + + macMap["src"] = info.SrcMAC.String() + macMap["dst"] = info.DstMAC.String() + ipMap["src"] = info.SrcIP.String() + ipMap["dst"] = info.DstIP.String() + portMap["src"] = info.SrcPort + portMap["dst"] = info.DstPort + m["id"] = info.ID m["proto"] = info.Protocol.String() - m["mac"] = map[string]string{ - "src": info.SrcMAC.String(), - "dst": info.DstMAC.String(), - } - m["ip"] = map[string]string{ - "src": info.SrcIP.String(), - "dst": info.DstIP.String(), - } - m["port"] = map[string]uint16{ - "src": info.SrcPort, - "dst": info.DstPort, - } + m["mac"] = macMap + m["ip"] = ipMap + m["port"] = portMap for anName, anProps := range info.Props { if len(anProps) != 0 { m[anName] = anProps } } + return macMap, ipMap, portMap +} + +func getSubMap() map[string]any { + m := subMapPool.Get().(map[string]any) + clear(m) + return m +} + +func putSubMap(m map[string]any) { + if m == nil { + return + } + clear(m) + subMapPool.Put(m) } func isBuiltInAnalyzer(name string) bool { @@ -329,11 +381,15 @@ func (v *idVisitor) Visit(node *ast.Node) { // idPatcher patches the AST during expr compilation, replacing certain values 
with // their internal representations for better runtime performance. type idPatcher struct { - FuncMap map[string]*Function - Err error + FuncMap map[string]*Function + GeoMatcher *geo.GeoMatcher + Err error } func (p *idPatcher) Visit(node *ast.Node) { + if p.tryPatchGeoSiteORChain(node) { + return + } switch (*node).(type) { case *ast.CallNode: callNode := (*node).(*ast.CallNode) @@ -352,6 +408,108 @@ func (p *idPatcher) Visit(node *ast.Node) { } } +func (p *idPatcher) tryPatchGeoSiteORChain(node *ast.Node) bool { + if p == nil || p.GeoMatcher == nil { + return false + } + terms, ok := collectGeoSiteORChain(*node) + if !ok || len(terms) < 2 { + return false + } + hostExpr := strings.TrimSpace(terms[0].hostExpr) + if hostExpr == "" { + return false + } + conditions := make([]string, 0, len(terms)) + for _, term := range terms { + if strings.TrimSpace(term.hostExpr) != hostExpr { + return false + } + conditions = append(conditions, term.condition) + } + normalized := normalizeUniqueLowerStrings(conditions) + if len(normalized) < 2 { + return false + } + hostNode, err := parser.Parse(hostExpr) + if err != nil || hostNode == nil || hostNode.Node == nil { + return false + } + call := &ast.CallNode{ + Callee: &ast.IdentifierNode{Value: "geosite_set"}, + Arguments: []ast.Node{ + hostNode.Node, + &ast.ConstantNode{Value: &geo.SiteConditionSet{Conditions: normalized}}, + }, + } + ast.Patch(node, call) + return true +} + +type geositeTerm struct { + hostExpr string + condition string +} + +func collectGeoSiteORChain(node ast.Node) ([]geositeTerm, bool) { + switch n := node.(type) { + case *ast.BinaryNode: + if n.Operator != "or" && n.Operator != "||" { + return nil, false + } + left, ok := collectGeoSiteORChain(n.Left) + if !ok { + return nil, false + } + right, ok := collectGeoSiteORChain(n.Right) + if !ok { + return nil, false + } + out := make([]geositeTerm, 0, len(left)+len(right)) + out = append(out, left...) + out = append(out, right...) 
+ return out, true + case *ast.CallNode: + idNode, ok := n.Callee.(*ast.IdentifierNode) + if !ok || len(n.Arguments) < 2 { + return nil, false + } + name := strings.ToLower(idNode.Value) + if name == "geosite" { + condNode, ok := n.Arguments[1].(*ast.StringNode) + if !ok { + return nil, false + } + return []geositeTerm{{ + hostExpr: n.Arguments[0].String(), + condition: condNode.Value, + }}, true + } + if name != "geosite_set" { + return nil, false + } + setNode, ok := n.Arguments[1].(*ast.ConstantNode) + if !ok || setNode.Value == nil { + return nil, false + } + set, ok := setNode.Value.(*geo.SiteConditionSet) + if !ok || set == nil { + return nil, false + } + if len(set.Conditions) == 0 { + return nil, false + } + out := make([]geositeTerm, 0, len(set.Conditions)) + hostExpr := n.Arguments[0].String() + for _, condition := range set.Conditions { + out = append(out, geositeTerm{hostExpr: hostExpr, condition: condition}) + } + return out, true + default: + return nil, false + } +} + type Function struct { InitFunc func() error PatchFunc func(args *[]ast.Node) error @@ -359,7 +517,7 @@ type Function struct { Types []reflect.Type } -func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatcher) { +func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*Function, *geo.GeoMatcher) { geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) return map[string]*Function{ "geoip": { @@ -378,6 +536,16 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, }, + "geosite_set": { + InitFunc: geoMatcher.LoadGeoSite, + PatchFunc: nil, + Func: func(params ...any) (any, error) { + return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil + }, + Types: []reflect.Type{ + reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)), + }, + }, "cidr": { InitFunc: nil, PatchFunc: 
func(args *[]ast.Node) error { @@ -425,9 +593,20 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc return nil }, Func: func(params ...any) (any, error) { + start := time.Now() + if stats != nil { + stats.LookupCalls.Add(1) + defer func() { + stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) + }() + } ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) defer cancel() - return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) + out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) + if err != nil && stats != nil { + stats.LookupErrors.Add(1) + } + return out, err }, Types: []reflect.Type{ reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), diff --git a/ruleset/expr_test.go b/ruleset/expr_test.go index 23345f0..320a9f2 100644 --- a/ruleset/expr_test.go +++ b/ruleset/expr_test.go @@ -2,9 +2,14 @@ package ruleset import ( "reflect" + "strings" "testing" "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo" + + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser" ) func TestExtractGeoSiteConditions(t *testing.T) { @@ -63,3 +68,23 @@ func TestMatchGeoSiteConditions(t *testing.T) { t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want) } } + +func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) { + tree, err := parser.Parse(`geosite(tls.req.sni, "google") || geosite(tls.req.sni, "youtube") || geosite(tls.req.sni, "openai")`) + if err != nil { + t.Fatalf("parse expression: %v", err) + } + root := tree.Node + patcher := &idPatcher{GeoMatcher: geo.NewGeoMatcher("", "")} + ast.Walk(&root, patcher) + if patcher.Err != nil { + t.Fatalf("patch error: %v", patcher.Err) + } + got := root.String() + if !strings.Contains(got, "geosite_set(") { + t.Fatalf("expected geosite_set rewrite, got %q", got) + } + if strings.Contains(got, "||") || strings.Contains(got, " or ") { + t.Fatalf("expected OR 
chain to be collapsed, got %q", got) + } +} diff --git a/ruleset/interface.go b/ruleset/interface.go index f7821eb..7351f0a 100644 --- a/ruleset/interface.go +++ b/ruleset/interface.go @@ -4,6 +4,7 @@ import ( "context" "net" "strconv" + "sync/atomic" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/modifier" @@ -95,6 +96,28 @@ type Ruleset interface { Match(StreamInfo) MatchResult } +type Stats struct { + MatchCalls uint64 + MatchErrors uint64 + MatchLatencyNanos uint64 + LookupCalls uint64 + LookupErrors uint64 + LookupLatencyNanos uint64 +} + +type statsCounters struct { + MatchCalls atomic.Uint64 + MatchErrors atomic.Uint64 + MatchLatencyNanos atomic.Uint64 + LookupCalls atomic.Uint64 + LookupErrors atomic.Uint64 + LookupLatencyNanos atomic.Uint64 +} + +type StatsProvider interface { + Stats() Stats +} + // Logger is the logging interface for the ruleset. type Logger interface { Log(info StreamInfo, name string) diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..43a7a58 --- /dev/null +++ b/stats.go @@ -0,0 +1,22 @@ +package mellaris + +import ( + "git.difuse.io/Difuse/Mellaris/engine" + "git.difuse.io/Difuse/Mellaris/ruleset" +) + +type Stats struct { + Engine engine.Stats + Ruleset ruleset.Stats +} + +func (a *App) Stats() Stats { + if a == nil || a.engine == nil { + return Stats{} + } + out := Stats{Engine: a.engine.Stats()} + if rs, ok := a.ruleset.(ruleset.StatsProvider); ok { + out.Ruleset = rs.Stats() + } + return out +}