diff --git a/engine/engine.go b/engine/engine.go index a221187..500363c 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -27,6 +27,7 @@ const ( type engine struct { logger Logger io io.PacketIO + macResolver *sourceMACResolver workers []*worker stats *statsCounters verdicts sync.Map // streamID(uint32) -> verdictEntry @@ -46,7 +47,7 @@ func NewEngine(config Config) (Engine, error) { } overflowPolicy := config.OverflowPolicy if overflowPolicy == "" { - overflowPolicy = OverflowPolicyDrop + overflowPolicy = OverflowPolicyAccept } selectionMode := config.AnalyzerSelectionMode if selectionMode == "" { @@ -80,6 +81,7 @@ func NewEngine(config Config) (Engine, error) { e := &engine{ logger: config.Logger, io: config.IO, + macResolver: macResolver, workers: workers, stats: stats, overflowPolicy: overflowPolicy, @@ -105,6 +107,9 @@ func (e *engine) Run(ctx context.Context) error { for _, w := range e.workers { go w.Run(ioCtx) } + if e.macResolver != nil { + go e.macResolver.Run(ioCtx) + } go e.drainResults(ioCtx) go e.sweepVerdicts(ioCtx) diff --git a/engine/mac_resolver.go b/engine/mac_resolver.go index 30c1168..99dbf68 100644 --- a/engine/mac_resolver.go +++ b/engine/mac_resolver.go @@ -5,6 +5,7 @@ package engine import ( "bufio" + "context" "net" "os" "os/exec" @@ -52,38 +53,6 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr { return nil } - now := time.Now() - r.mu.RLock() - ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL - arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL - ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL - if mac := r.ifaceByIP[ipKey]; len(mac) != 0 { - out := append(net.HardwareAddr(nil), mac...) - r.mu.RUnlock() - return out - } - if mac := r.arpByIP[ipKey]; len(mac) != 0 && !arpRefreshDue { - out := append(net.HardwareAddr(nil), mac...) - r.mu.RUnlock() - return out - } - if mac := r.ndpByIP[ipKey]; len(mac) != 0 && !ndpRefreshDue { - out := append(net.HardwareAddr(nil), mac...) - r.mu.RUnlock() - return out - } - r.mu.RUnlock() - - if ifaceRefreshDue { - r.refreshIfaceCache(now) - } - if arpRefreshDue { - r.refreshARPCache(now) - } - if ndpRefreshDue { - r.refreshNDPCache(now) - } - r.mu.RLock() defer r.mu.RUnlock() if mac := r.ifaceByIP[ipKey]; len(mac) != 0 { @@ -95,18 +64,38 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr { if mac := r.ndpByIP[ipKey]; len(mac) != 0 { return append(net.HardwareAddr(nil), mac...) } + return nil +} - // On-demand IPv6 neighbor lookup via route-netlink as a last fast path. - if ip.To4() == nil { - if mac, ok := lookupNeighborMACNetlink(ip); ok { - out := append(net.HardwareAddr(nil), mac...) - r.mu.Lock() - r.ndpByIP[ipKey] = append(net.HardwareAddr(nil), mac...) - r.mu.Unlock() - return out +func (r *sourceMACResolver) Run(ctx context.Context) { + r.refreshAll(time.Now()) + ticker := time.NewTicker(arpCacheTTL) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + r.refreshAll(now) } } - return nil +} + +func (r *sourceMACResolver) refreshAll(now time.Time) { + r.mu.RLock() + ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL + arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL + ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL + r.mu.RUnlock() + if ifaceRefreshDue { + r.refreshIfaceCache(now) + } + if arpRefreshDue { + r.refreshARPCache(now) + } + if ndpRefreshDue { + r.refreshNDPCache(now) + } } func (r *sourceMACResolver) refreshIfaceCache(now time.Time) { diff --git a/engine/mac_resolver_stub.go b/engine/mac_resolver_stub.go index 4cf4e55..bd27c3d 100644 --- a/engine/mac_resolver_stub.go +++ b/engine/mac_resolver_stub.go @@ -3,7 +3,10 @@ package engine -import "net" +import ( + "context" + "net" +) type sourceMACResolver struct{} @@ -15,3 +18,7 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr { _ = ip return nil } + +func (r *sourceMACResolver) Run(ctx context.Context) { + <-ctx.Done() +} diff --git a/engine/overflow_test.go b/engine/overflow_test.go new file mode 100644 index 0000000..99abc33 --- /dev/null +++ b/engine/overflow_test.go @@ -0,0 +1,70 @@ +package engine + +import ( + "context" + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/io" +) + +type recordingPacket struct { + streamID uint32 + data []byte +} + +func (p recordingPacket) StreamID() uint32 { return p.streamID } +func (p recordingPacket) Data() []byte { return p.data } + +type recordingPacketIO struct { + verdict io.Verdict +} + +func (r *recordingPacketIO) Register(context.Context, io.PacketCallback) error { return nil } +func (r *recordingPacketIO) SetVerdict(_ io.Packet, v io.Verdict, _ []byte) error { + r.verdict = v + return nil +} +func (r *recordingPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) { + return nil, nil +} +func (r *recordingPacketIO) Close() error { return nil } + +func TestEngineDefaultOverflowPolicyAccepts(t *testing.T) { + packetIO := &recordingPacketIO{} + eng, err := NewEngine(Config{ + Logger: noopTestLogger{}, + IO: packetIO, + Workers: 1, + WorkerQueueSize: 1, + }) + if err != nil { + t.Fatalf("NewEngine error: %v", err) + } + e := eng.(*engine) + if e.overflowPolicy != OverflowPolicyAccept { + t.Fatalf("overflow policy=%v want=%v", e.overflowPolicy, OverflowPolicyAccept) + } + + e.workers[0].packetChan <- &workerPacket{} + packet := recordingPacket{ + streamID: 1, + data: serializeIPv6TCP( + t, + net.ParseIP("2001:db8::11").To16(), + net.ParseIP("2001:db8::22").To16(), + 42310, + 443, + 1000, + ), + } + + e.dispatch(packet) + stats := e.Stats() + if packetIO.verdict != io.VerdictAccept { + t.Fatalf("overflow verdict=%v want=%v", packetIO.verdict, io.VerdictAccept) + } + if stats.OverflowEvents != 1 || stats.OverflowAccepts != 1 || stats.OverflowDrops != 0 { + t.Fatalf("overflow stats=%+v", stats) + } +} diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go index 147457c..06dfdaf 100644 --- a/engine/reload_rules_test.go +++ b/engine/reload_rules_test.go @@ -22,6 +22,89 @@ func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult { return ruleset.MatchResult{Action: r.action} } +type analyzerRuleset struct { + action ruleset.Action + ans []analyzer.Analyzer +} + +func (r analyzerRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { + return r.ans +} + +func (r analyzerRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult { + return ruleset.MatchResult{Action: r.action} +} + +type countingTCPAnalyzer struct { + newCalls *int + feedCalls *int +} + +func (a countingTCPAnalyzer) Name() string { return "tls" } +func (a countingTCPAnalyzer) Limit() int { return 0 } +func (a countingTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream { + (*a.newCalls)++ + return countingTCPStream{feedCalls: a.feedCalls} +} + +type countingTCPStream struct { + feedCalls *int +} + +func (s countingTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) { + (*s.feedCalls)++ + return nil, false +} + +func (s countingTCPStream) Close(bool) *analyzer.PropUpdate { + return nil +} + +type logFinalizingRuleset struct { + ans []analyzer.Analyzer +} + +func (r logFinalizingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { + return r.ans +} + +func (r logFinalizingRuleset) Match(info ruleset.StreamInfo) ruleset.MatchResult { + if _, ok := info.Props["tls"]; ok { + return ruleset.MatchResult{Action: ruleset.ActionMaybe, Logged: true} + } + return ruleset.MatchResult{Action: ruleset.ActionMaybe} +} + +func (r logFinalizingRuleset) CanFinalizeAfterLog(ruleset.StreamInfo, []string) bool { + return true +} + +type requestPropTCPAnalyzer struct { + closeCalls *int +} + +func (a requestPropTCPAnalyzer) Name() string { return "tls" } +func (a requestPropTCPAnalyzer) Limit() int { return 0 } +func (a requestPropTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream { + return requestPropTCPStream{closeCalls: a.closeCalls} +} + +type requestPropTCPStream struct { + closeCalls *int +} + +func (s requestPropTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) { + return &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}}, + }, false +} + +func (s requestPropTCPStream) Close(bool) *analyzer.PropUpdate { + (*s.closeCalls)++ + return nil +} + type noopTestLogger struct{} func (noopTestLogger) WorkerStart(int) {} @@ -207,3 +290,90 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) { t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream) } } + +func TestTCPFlowDelaysAnalyzerCreationUntilPayload(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + newCalls := 0 + feedCalls := 0 + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})) + mgr.updateRuleset(analyzerRuleset{ + action: ruleset.ActionMaybe, + ans: []analyzer.Analyzer{countingTCPAnalyzer{ + newCalls: &newCalls, + feedCalls: &feedCalls, + }}, + }, 0) + + l3 := L3Info{ + Version: 4, + Protocol: 6, + SrcIP: [4]byte{10, 0, 0, 1}, + DstIP: [4]byte{10, 0, 0, 2}, + } + tcp := TCPInfo{ + SrcPort: 12345, + DstPort: 443, + Seq: 100, + } + + v := mgr.handle(1, l3, tcp, nil, nil, nil) + if v != io.VerdictAccept { + t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept) + } + if newCalls != 0 || feedCalls != 0 { + t.Fatalf("empty packet created/feed analyzer: new=%d feed=%d", newCalls, feedCalls) + } + + tcp.Seq = 101 + v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil) + if v != io.VerdictAccept { + t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAccept) + } + if newCalls != 1 || feedCalls != 1 { + t.Fatalf("payload should create/feed analyzer once: new=%d feed=%d", newCalls, feedCalls) + } +} + +func TestTCPFlowFinalizesAfterLogClassification(t *testing.T) { + node, err := snowflake.NewNode(0) + if err != nil { + t.Fatalf("create node: %v", err) + } + closeCalls := 0 + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})) + mgr.updateRuleset(logFinalizingRuleset{ + ans: []analyzer.Analyzer{requestPropTCPAnalyzer{closeCalls: &closeCalls}}, + }, 0) + + l3 := L3Info{ + Version: 4, + Protocol: 6, + SrcIP: [4]byte{10, 0, 0, 1}, + DstIP: [4]byte{10, 0, 0, 2}, + } + tcp := TCPInfo{ + SrcPort: 12345, + DstPort: 443, + Seq: 100, + } + + v := mgr.handle(1, l3, tcp, nil, nil, nil) + if v != io.VerdictAccept { + t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept) + } + + tcp.Seq = 101 + v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil) + if v != io.VerdictAcceptStream { + t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAcceptStream) + } + if closeCalls != 1 { + t.Fatalf("expected analyzer to be closed once after finalization, got %d", closeCalls) + } + if _, ok := mgr.flows[1]; ok { + t.Fatal("expected finalized TCP flow to be removed from manager") + } +} diff --git a/engine/tcp_flow.go b/engine/tcp_flow.go index 2d8e944..45041d0 100644 --- a/engine/tcp_flow.go +++ b/engine/tcp_flow.go @@ -27,6 +27,8 @@ type tcpFlow struct { streamID uint32 srcPort uint16 dstPort uint16 + srcIP net.IP + dstIP net.IP dirSeq [2]uint32 dirBuf [2][]byte @@ -41,6 +43,9 @@ type tcpFlow struct { lastVerdict io.Verdict feedCalled [2]bool lastSeen time.Time + + pendingAnalyzers []analyzer.Analyzer + selector *analyzerSelector } type tcpFlowEntry struct { @@ -54,7 +59,7 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { rs, version := f.currentRuleset() rulesetChanged := version != f.rulesetVersion - if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 { + if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() { return f.lastVerdict } @@ -68,6 +73,9 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { propUpdated := false if len(payload) > 0 { dir, rev := f.resolveDirection(tcp) + if len(f.pendingAnalyzers) > 0 { + f.initPendingAnalyzers(payload) + } expected := f.dirSeq[dir] if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected { f.feedCalled[dir] = true @@ -108,6 +116,52 @@ func (f *tcpFlow) feedAnalyzers(rev bool) bool { return updated } +func (f *tcpFlow) initPendingAnalyzers(payload []byte) { + baseAns := f.pendingAnalyzers + f.pendingAnalyzers = nil + if f.selector != nil { + baseAns = f.selector.SelectTCP(baseAns, payload) + } + ans := analyzersToTCPAnalyzers(baseAns) + if len(ans) == 0 { + return + } + entries := make([]*tcpFlowEntry, 0, len(ans)) + for _, a := range ans { + entries = append(entries, &tcpFlowEntry{ + Name: a.Name(), + Stream: a.NewTCP(analyzer.TCPInfo{ + SrcIP: f.srcIP, + DstIP: f.dstIP, + SrcPort: f.srcPort, + DstPort: f.dstPort, + }, &analyzerLogger{ + StreamID: f.info.ID, + Name: a.Name(), + Logger: f.logger, + }), + HasLimit: a.Limit() > 0, + Quota: a.Limit(), + }) + } + f.activeEntries = append(f.activeEntries, entries...) +} + +func (f *tcpFlow) hasPendingAnalyzers() bool { + return len(f.pendingAnalyzers) > 0 +} + +func (f *tcpFlow) analyzerNames() []string { + names := make([]string, 0, len(f.activeEntries)+len(f.pendingAnalyzers)) + for _, entry := range f.activeEntries { + names = append(names, entry.Name) + } + for _, a := range f.pendingAnalyzers { + names = append(names, a.Name()) + } + return names +} + func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) { if !propUpdated && !f.virgin && !rulesetChanged { return @@ -125,11 +179,15 @@ func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bo f.lastVerdict = verdict f.closeActiveEntries() f.logger.TCPStreamAction(f.info, action, false) + } else if result.Logged && canFinalizeAfterLog(rs, f.info, f.analyzerNames()) { + f.lastVerdict = io.VerdictAcceptStream + f.closeActiveEntries() + f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true) } } func (f *tcpFlow) maybeFinalizeVerdict() { - if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept { + if len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() && f.lastVerdict == io.VerdictAccept { f.lastVerdict = io.VerdictAcceptStream f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true) } @@ -231,8 +289,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay var ans []analyzer.TCPAnalyzer if rs != nil { baseAns := rs.Analyzers(info) - baseAns = m.selector.SelectTCP(baseAns, payload) - ans = analyzersToTCPAnalyzers(baseAns) + if len(payload) > 0 { + baseAns = m.selector.SelectTCP(baseAns, payload) + ans = analyzersToTCPAnalyzers(baseAns) + } } entries := make([]*tcpFlowEntry, 0, len(ans)) for _, a := range ans { @@ -257,6 +317,8 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay streamID: streamID, srcPort: tcp.SrcPort, dstPort: tcp.DstPort, + srcIP: ipSrc, + dstIP: ipDst, info: info, virgin: true, logger: m.logger, @@ -265,6 +327,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay activeEntries: entries, lastVerdict: io.VerdictAccept, lastSeen: time.Now(), + selector: m.selector, + } + if len(payload) == 0 && rs != nil { + flow.pendingAnalyzers = rs.Analyzers(info) } flow.dirSeq[tcpDirC2S] = tcp.Seq + 1 return flow @@ -325,3 +391,8 @@ func actionToTCPVerdict(a ruleset.Action) io.Verdict { return io.VerdictAcceptStream } } + +func canFinalizeAfterLog(rs ruleset.Ruleset, info ruleset.StreamInfo, activeAnalyzers []string) bool { + finalizer, ok := rs.(ruleset.LogFinalizer) + return ok && finalizer.CanFinalizeAfterLog(info, activeAnalyzers) +} diff --git a/engine/udp.go b/engine/udp.go index e732c23..2339b9f 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -2,6 +2,7 @@ package engine import ( "bytes" + "container/list" "errors" "net" "sync" @@ -12,7 +13,6 @@ import ( "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" - lru "github.com/hashicorp/golang-lru/v2" ) // udpVerdict is a subset of io.Verdict for UDP streams. @@ -116,15 +116,18 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { type udpStreamManager struct { factory *udpStreamFactory - streams *lru.Cache[uint32, *udpStreamValue] + streams map[uint32]*list.Element + order *list.List + maxStreams int tupleIndex map[udpTupleKey]uint32 streamTuples map[uint32]udpTupleKey stats *statsCounters } type udpStreamValue struct { - Stream *udpStream - Tuple udpTupleKey + StreamID uint32 + Stream *udpStream + Tuple udpTupleKey } func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) { @@ -143,27 +146,23 @@ type udpTupleKey struct { } func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { + if maxStreams <= 0 { + maxStreams = 1 + } m := &udpStreamManager{ factory: factory, + streams: make(map[uint32]*list.Element, maxStreams), + order: list.New(), + maxStreams: maxStreams, tupleIndex: make(map[udpTupleKey]uint32, maxStreams), streamTuples: make(map[uint32]udpTupleKey, maxStreams), stats: stats, } - ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) { - if v != nil && v.Stream != nil { - v.Stream.Close() - } - m.removeTupleMappingLocked(k) - }) - if err != nil { - return nil, err - } - m.streams = ss return m, nil } func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) { - value, ok := m.streams.Get(streamID) + value, ok := m.get(streamID) if !ok { if m.stats != nil { m.stats.UDPTupleLookups.Add(1) @@ -176,7 +175,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, m.stats.UDPTupleHits.Add(1) } var hasValue bool - matchedValue, hasValue = m.streams.Get(matchedKey) + matchedValue, hasValue = m.get(matchedKey) if !hasValue || matchedValue == nil { delete(m.tupleIndex, tuple) delete(m.streamTuples, matchedKey) @@ -188,16 +187,18 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, value = matchedValue rev = matchedRev if matchedKey != streamID { - m.streams.Remove(matchedKey) - m.streams.Add(streamID, matchedValue) + m.remove(matchedKey, false) + matchedValue.StreamID = streamID + m.add(streamID, matchedValue) m.bindTupleLocked(streamID, tuple) } } else { value = &udpStreamValue{ - Stream: m.factory.New(tuple, payload, uc), - Tuple: tuple, + StreamID: streamID, + Stream: m.factory.New(tuple, payload, uc), + Tuple: tuple, } - m.streams.Add(streamID, value) + m.add(streamID, value) m.bindTupleLocked(streamID, tuple) } } else { @@ -205,10 +206,11 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, if !ok { value.Stream.Close() value = &udpStreamValue{ - Stream: m.factory.New(tuple, payload, uc), - Tuple: tuple, + StreamID: streamID, + Stream: m.factory.New(tuple, payload, uc), + Tuple: tuple, } - m.streams.Add(streamID, value) + m.add(streamID, value) m.bindTupleLocked(streamID, tuple) } } @@ -217,6 +219,55 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, } } +func (m *udpStreamManager) get(streamID uint32) (*udpStreamValue, bool) { + ele, ok := m.streams[streamID] + if !ok || ele == nil { + return nil, false + } + m.order.MoveToFront(ele) + value, ok := ele.Value.(*udpStreamValue) + return value, ok && value != nil +} + +func (m *udpStreamManager) add(streamID uint32, value *udpStreamValue) { + if value == nil { + return + } + if existing, ok := m.streams[streamID]; ok { + existing.Value = value + m.order.MoveToFront(existing) + return + } + value.StreamID = streamID + m.streams[streamID] = m.order.PushFront(value) + for len(m.streams) > m.maxStreams { + back := m.order.Back() + if back == nil { + return + } + evicted, _ := back.Value.(*udpStreamValue) + if evicted == nil { + m.order.Remove(back) + continue + } + m.remove(evicted.StreamID, true) + } +} + +func (m *udpStreamManager) remove(streamID uint32, closeStream bool) { + ele, ok := m.streams[streamID] + if !ok || ele == nil { + return + } + value, _ := ele.Value.(*udpStreamValue) + delete(m.streams, streamID) + m.order.Remove(ele) + m.removeTupleMappingLocked(streamID) + if closeStream && value != nil && value.Stream != nil { + value.Stream.Close() + } +} + func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) { m.removeTupleMappingLocked(streamID) m.tupleIndex[key] = streamID diff --git a/ruleset/builtins/geo/matchers_v2geo.go b/ruleset/builtins/geo/matchers_v2geo.go index 68da971..44b1d20 100644 --- a/ruleset/builtins/geo/matchers_v2geo.go +++ b/ruleset/builtins/geo/matchers_v2geo.go @@ -112,23 +112,35 @@ type geositeDomain struct { } type geositeMatcher struct { - Domains []geositeDomain + Domains []geositeDomain // legacy slow path for tests and manual construction + Plain []geositeDomain + Regex []geositeDomain + Root map[string]geositeDomain + Full map[string]geositeDomain // Attributes are matched using "and" logic - if you have multiple attributes here, // a domain must have all of those attributes to be considered a match. Attrs []string } -func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool { - // Match attributes first - if len(m.Attrs) > 0 { - if len(domain.Attrs) == 0 { +func (m *geositeMatcher) attrsMatch(domain geositeDomain) bool { + if len(m.Attrs) == 0 { + return true + } + if len(domain.Attrs) == 0 { + return false + } + for _, attr := range m.Attrs { + if !domain.Attrs[attr] { return false } - for _, attr := range m.Attrs { - if !domain.Attrs[attr] { - return false - } - } + } + return true +} + +func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool { + // Match attributes first + if !m.attrsMatch(domain) { + return false } switch domain.Type { @@ -152,7 +164,35 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool { } func (m *geositeMatcher) Match(host HostInfo) bool { - for _, domain := range m.Domains { + if host.Name == "" { + return false + } + if domain, ok := m.Full[host.Name]; ok && m.attrsMatch(domain) { + return true + } + for name := host.Name; name != ""; { + if domain, ok := m.Root[name]; ok && m.attrsMatch(domain) { + return true + } + idx := strings.IndexByte(name, '.') + if idx < 0 { + break + } + name = name[idx+1:] + } + for _, domain := range m.Plain { + if m.matchDomain(domain, host) { + return true + } + } + if len(m.Plain) == 0 && len(m.Regex) == 0 && len(m.Root) == 0 && len(m.Full) == 0 { + for _, domain := range m.Domains { + if m.matchDomain(domain, host) { + return true + } + } + } + for _, domain := range m.Regex { if m.matchDomain(domain, host) { return true } @@ -161,45 +201,53 @@ func (m *geositeMatcher) Match(host HostInfo) bool { } func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) { - domains := make([]geositeDomain, len(list.Domain)) - for i, domain := range list.Domain { + matcher := &geositeMatcher{ + Root: make(map[string]geositeDomain), + Full: make(map[string]geositeDomain), + Attrs: attrs, + } + for _, domain := range list.Domain { + var compiled geositeDomain switch domain.Type { case v2geo.Domain_Plain: - domains[i] = geositeDomain{ + compiled = geositeDomain{ Type: geositeDomainPlain, Value: domain.Value, Attrs: domainAttributeToMap(domain.Attribute), } + matcher.Plain = append(matcher.Plain, compiled) case v2geo.Domain_Regex: regex, err := regexp.Compile(domain.Value) if err != nil { return nil, err } - domains[i] = geositeDomain{ + compiled = geositeDomain{ Type: geositeDomainRegex, + Value: domain.Value, Regex: regex, Attrs: domainAttributeToMap(domain.Attribute), } + matcher.Regex = append(matcher.Regex, compiled) case v2geo.Domain_Full: - domains[i] = geositeDomain{ + compiled = geositeDomain{ Type: geositeDomainFull, Value: domain.Value, Attrs: domainAttributeToMap(domain.Attribute), } + matcher.Full[domain.Value] = compiled case v2geo.Domain_RootDomain: - domains[i] = geositeDomain{ + compiled = geositeDomain{ Type: geositeDomainRoot, Value: domain.Value, Attrs: domainAttributeToMap(domain.Attribute), } + matcher.Root[domain.Value] = compiled default: return nil, errors.New("unsupported domain type") } + matcher.Domains = append(matcher.Domains, compiled) } - return &geositeMatcher{ - Domains: domains, - Attrs: attrs, - }, nil + return matcher, nil } func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool { diff --git a/ruleset/expr.go b/ruleset/expr.go index 7838db9..4b8b4f2 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -59,6 +59,8 @@ type compiledExprRule struct { Log bool ModInstance modifier.Instance Program *vm.Program + Native nativeExpr + AnalyzerRefs map[string]analyzerRuleRef GeoSiteConditions []string StartTimeSecs int // seconds since midnight, -1 if unset StopTimeSecs int // seconds since midnight, -1 if unset @@ -67,6 +69,7 @@ type compiledExprRule struct { } var _ Ruleset = (*exprRuleset)(nil) +var _ LogFinalizer = (*exprRuleset)(nil) var ( envPool = sync.Pool{ @@ -102,10 +105,12 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { }() } - env := envPool.Get().(map[string]any) - clear(env) - macMap, ipMap, portMap := populateExprEnv(env, info) + var env map[string]any + var macMap, ipMap, portMap map[string]any releaseEnv := func() { + if env == nil { + return + } clear(env) envPool.Put(env) putSubMap(macMap) @@ -113,31 +118,45 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { putSubMap(portMap) } now := time.Now() + logged := false for _, rule := range r.Rules { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { continue } - v, err := vm.Run(rule.Program, env) - if err != nil { - if r.stats != nil { - r.stats.MatchErrors.Add(1) + matched := false + if rule.Native != nil { + matched = rule.Native.Match(info) + } else { + if env == nil { + env = envPool.Get().(map[string]any) + clear(env) + macMap, ipMap, portMap = populateExprEnv(env, info) } - r.Logger.MatchError(info, rule.Name, err) - continue + v, err := vm.Run(rule.Program, env) + if err != nil { + if r.stats != nil { + r.stats.MatchErrors.Add(1) + } + r.Logger.MatchError(info, rule.Name, err) + continue + } + matched, _ = v.(bool) } - if vBool, ok := v.(bool); ok && vBool { + if matched { if rule.Log { logInfo := info if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil { logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions) } r.Logger.Log(logInfo, rule.Name) + logged = true } if rule.Action != nil { releaseEnv() return MatchResult{ Action: *rule.Action, ModInstance: rule.ModInstance, + Logged: logged, } } } @@ -145,9 +164,40 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { releaseEnv() return MatchResult{ Action: ActionMaybe, + Logged: logged, } } +func (r *exprRuleset) CanFinalizeAfterLog(info StreamInfo, activeAnalyzers []string) bool { + active := make(map[string]bool, len(activeAnalyzers)) + for _, name := range activeAnalyzers { + active[name] = true + } + for _, rule := range r.Rules { + if rule.Action == nil { + continue + } + if *rule.Action == ActionModify { + return false + } + if rule.StartTimeSecs != -1 || rule.StopTimeSecs != -1 || len(rule.Weekdays) != 0 { + return false + } + for name, ref := range rule.AnalyzerRefs { + if !active[name] { + continue + } + if ref.ResponseSide { + return false + } + if _, ok := info.Props[name]; !ok { + return false + } + } + } + return true +} + func (r *exprRuleset) Stats() Stats { if r == nil || r.stats == nil { return Stats{} @@ -242,17 +292,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier if err != nil { return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err) } + var analyzerRefs map[string]analyzerRuleRef + if refTree, err := parser.Parse(rule.Expr); err == nil && refTree != nil { + analyzerRefs = collectAnalyzerRefs(refTree.Node, fullAnMap) + } cr := compiledExprRule{ Name: rule.Name, Action: action, Log: rule.Log, Program: program, + AnalyzerRefs: analyzerRefs, GeoSiteConditions: extractGeoSiteConditions(rule.Expr), StartTimeSecs: startSecs, StopTimeSecs: stopSecs, Weekdays: weekdays, WeekdaysNegated: weekdaysNegated, } + cr.Native = compileNativeExpr(rule.Expr, funcMap, geoMatcher) if action != nil && *action == ActionModify { mod, ok := fullModMap[rule.Modifier.Name] if !ok { @@ -266,9 +322,16 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier } compiledRules = append(compiledRules, cr) } + depAns := make([]analyzer.Analyzer, 0, len(depAnMap)) + for _, a := range ans { + if depAnMap[a.Name()] != nil { + depAns = append(depAns, a) + } + } + return &exprRuleset{ Rules: compiledRules, - Ans: ans, + Ans: depAns, Logger: config.Logger, GeoMatcher: geoMatcher, stats: stats, @@ -373,6 +436,58 @@ func (v *idVisitor) Visit(node *ast.Node) { } } +type analyzerRuleRef struct { + ResponseSide bool +} + +type analyzerRefVisitor struct { + Analyzers map[string]analyzer.Analyzer + Refs map[string]analyzerRuleRef +} + +func collectAnalyzerRefs(root ast.Node, analyzers map[string]analyzer.Analyzer) map[string]analyzerRuleRef { + visitor := &analyzerRefVisitor{ + Analyzers: analyzers, + Refs: make(map[string]analyzerRuleRef), + } + ast.Walk(&root, visitor) + return visitor.Refs +} + +func (v *analyzerRefVisitor) Visit(node *ast.Node) { + switch n := (*node).(type) { + case *ast.IdentifierNode: + if _, ok := v.Analyzers[n.Value]; ok { + v.add(n.Value, false) + } + case *ast.MemberNode: + path := memberPath(n) + if len(path) == 0 { + return + } + name := path[0] + if _, ok := v.Analyzers[name]; !ok { + return + } + v.add(name, len(path) > 1 && isResponseSideAnalyzerPath(path[1])) + } +} + +func (v *analyzerRefVisitor) add(name string, responseSide bool) { + ref := v.Refs[name] + ref.ResponseSide = ref.ResponseSide || responseSide + v.Refs[name] = ref +} + +func isResponseSideAnalyzerPath(name string) bool { + switch name { + case "resp", "server", "answers", "response": + return true + default: + return false + } +} + // idPatcher patches the AST during expr compilation, replacing certain values with // their internal representations for better runtime performance. type idPatcher struct { diff --git a/ruleset/expr_test.go b/ruleset/expr_test.go index 320a9f2..14e0161 100644 --- a/ruleset/expr_test.go +++ b/ruleset/expr_test.go @@ -1,6 +1,7 @@ package ruleset import ( + "net" "reflect" "strings" "testing" @@ -12,6 +13,13 @@ import ( "github.com/expr-lang/expr/parser" ) +type testAnalyzer struct { + name string +} + +func (a testAnalyzer) Name() string { return a.name } +func (a testAnalyzer) Limit() int { return 0 } + func TestExtractGeoSiteConditions(t *testing.T) { expression := ` (geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) && @@ -88,3 +96,93 @@ func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) { t.Fatalf("expected OR chain to be collapsed, got %q", got) } } + +func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) { + rs, err := CompileExprRules([]ExprRule{ + {Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`}, + }, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{}) + if err != nil { + t.Fatalf("CompileExprRules error: %v", err) + } + exprRS := rs.(*exprRuleset) + if len(exprRS.Ans) != 0 { + t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans)) + } + if exprRS.Rules[0].Native == nil { + t.Fatalf("expected network-only rule to compile to native matcher") + } + got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443}) + if got.Action != ActionAllow { + t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow) + } +} + +func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) { + rs, err := CompileExprRules([]ExprRule{ + {Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`}, + }, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{}) + if err != nil { + t.Fatalf("CompileExprRules error: %v", err) + } + exprRS := rs.(*exprRuleset) + if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" { + t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans) + } +} + +func TestNativeCIDRMatcher(t *testing.T) { + funcMap, geoMatcher := buildFunctionMapForTest() + n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher) + if n == nil { + t.Fatal("expected native matcher") + } + if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) { + t.Fatal("expected native CIDR matcher to match") + } + if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) { + t.Fatal("expected native CIDR matcher not to match") + } +} + +func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) { + rs, err := CompileExprRules([]ExprRule{ + {Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`}, + {Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`}, + }, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{}) + if err != nil { + t.Fatalf("CompileExprRules error: %v", err) + } + + info := StreamInfo{ + Props: analyzer.CombinedPropMap{ + "tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}}, + }, + } + if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) { + t.Fatal("expected request-only rules to allow log finalization once request props exist") + } +} + +func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) { + rs, err := CompileExprRules([]ExprRule{ + {Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`}, + {Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`}, + }, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{}) + if err != nil { + t.Fatalf("CompileExprRules error: %v", err) + } + + info := StreamInfo{ + Props: analyzer.CombinedPropMap{ + "tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}}, + }, + } + if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) { + t.Fatal("expected response-side rule to keep inspection open") + } +} + +func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) { + m, g := buildFunctionMap(&BuiltinConfig{}, nil) + return m, g +} diff --git a/ruleset/interface.go b/ruleset/interface.go index 7351f0a..fb06271 100644 --- a/ruleset/interface.go +++ b/ruleset/interface.go @@ -85,6 +85,7 @@ func (i StreamInfo) DstString() string { type MatchResult struct { Action Action ModInstance modifier.Instance + Logged bool } type Ruleset interface { @@ -96,6 +97,10 @@ type Ruleset interface { Match(StreamInfo) MatchResult } +type LogFinalizer interface { + CanFinalizeAfterLog(StreamInfo, []string) bool +} + type Stats struct { MatchCalls uint64 MatchErrors uint64 diff --git a/ruleset/native.go b/ruleset/native.go new file mode 100644 index 0000000..b96d484 --- /dev/null +++ b/ruleset/native.go @@ -0,0 +1,273 @@ +package ruleset + +import ( + "net" + "strconv" + "strings" + + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser" + + "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo" +) + +type nativeExpr interface { + Match(StreamInfo) bool +} + +type nativeBoolFunc func(StreamInfo) bool + +func (f nativeBoolFunc) Match(info StreamInfo) bool { + return f(info) +} + +type nativeValueFunc func(StreamInfo) (any, bool) + +func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr { + tree, err := parser.Parse(expression) + if err != nil || tree == nil || tree.Node == nil { + return nil + } + root := tree.Node + patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm} + ast.Walk(&root, patcher) + if patcher.Err != nil { + return nil + } + return compileNativeBool(root) +} + +func compileNativeBool(node ast.Node) nativeExpr { + switch n := node.(type) { + case *ast.BinaryNode: + switch n.Operator { + case "&&", "and": + left := compileNativeBool(n.Left) + right := compileNativeBool(n.Right) + if left == nil || right == nil { + return nil + } + return nativeBoolFunc(func(info StreamInfo) bool { + return left.Match(info) && right.Match(info) + }) + case "||", "or": + left := compileNativeBool(n.Left) + right := compileNativeBool(n.Right) + if left == nil || right == nil { + return nil + } + return nativeBoolFunc(func(info StreamInfo) bool { + return left.Match(info) || right.Match(info) + }) + case "==", "!=", ">", ">=", "<", "<=": + left := compileNativeValue(n.Left) + right := compileNativeValue(n.Right) + if left == nil || right == nil { + return nil + } + op := n.Operator + return nativeBoolFunc(func(info StreamInfo) bool { + lv, lok := left(info) + rv, rok := right(info) + if !lok || !rok { + return false + } + result, ok := compareNativeValues(lv, rv, op) + return ok && result + }) + default: + return nil + } + case *ast.UnaryNode: + if n.Operator != "!" && n.Operator != "not" { + return nil + } + child := compileNativeBool(n.Node) + if child == nil { + return nil + } + return nativeBoolFunc(func(info StreamInfo) bool { + return !child.Match(info) + }) + case *ast.CallNode: + return compileNativeCall(n) + case *ast.BoolNode: + value := n.Value + return nativeBoolFunc(func(StreamInfo) bool { return value }) + default: + return nil + } +} + +func compileNativeCall(n *ast.CallNode) nativeExpr { + id, ok := n.Callee.(*ast.IdentifierNode) + if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 { + return nil + } + ipValue := compileNativeValue(n.Arguments[0]) + if ipValue == nil { + return nil + } + var cidr *net.IPNet + switch arg := n.Arguments[1].(type) { + case *ast.ConstantNode: + cidr, _ = arg.Value.(*net.IPNet) + case *ast.StringNode: + _, parsed, err := net.ParseCIDR(arg.Value) + if err == nil { + cidr = parsed + } + } + if cidr == nil { + return nil + } + return nativeBoolFunc(func(info StreamInfo) bool { + value, ok := ipValue(info) + if !ok { + return false + } + switch v := value.(type) { + case net.IP: + return cidr.Contains(v) + case string: + ip := net.ParseIP(v) + return ip != nil && cidr.Contains(ip) + default: + return false + } + }) +} + +func compileNativeValue(node ast.Node) nativeValueFunc { + switch n := node.(type) { + case *ast.StringNode: + value := n.Value + return func(StreamInfo) (any, bool) { return value, true } + case *ast.IntegerNode: + value := int64(n.Value) + return func(StreamInfo) (any, bool) { return value, true } + case *ast.IdentifierNode: + switch strings.ToLower(n.Value) { + case "proto": + return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true } + default: + return nil + } + case *ast.MemberNode: + return compileNativeMember(n) + default: + return nil + } +} + +func compileNativeMember(n *ast.MemberNode) nativeValueFunc { + path := memberPath(n) + switch strings.Join(path, ".") { + case "mac.src": + return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true } + case "mac.dst": + return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true } + case "ip.src": + return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil } + case "ip.dst": + return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil } + case "port.src": + return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true } + case "port.dst": + return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true } + default: + return nil + } +} + +func memberPath(node ast.Node) []string { + switch n := node.(type) { + case *ast.IdentifierNode: + return []string{strings.ToLower(n.Value)} + case *ast.MemberNode: + base := memberPath(n.Node) + prop, ok := n.Property.(*ast.StringNode) + if !ok { + return nil + } + return append(base, strings.ToLower(prop.Value)) + default: + return nil + } +} + +func compareNativeValues(left, right any, op string) (bool, bool) { + if li, lok := nativeInt(left); lok { + ri, rok := nativeInt(right) + if !rok { + return false, false + } + return compareNativeOrdered(li, ri, op), true + } + ls, lok := nativeString(left) + if !lok { + return false, false + } + rs, rok := nativeString(right) + if !rok { + return false, false + } + switch op { + case "==": + return ls == rs, true + case "!=": + return ls != rs, true + default: + return false, false + } +} + +func compareNativeOrdered(left, right int64, op string) bool { + switch op { + case "==": + return left == right + case "!=": + return left != right + case ">": + return left > right + case ">=": + return left >= right + case "<": + return left < right + case "<=": + return left <= right + default: + return false + } +} + +func nativeInt(v any) (int64, bool) { + switch n := v.(type) { + case int: + return int64(n), true + case int64: + return n, true + case uint16: + return int64(n), true + case *ast.IntegerNode: + return int64(n.Value), true + default: + return 0, false + } +} + +func nativeString(v any) (string, bool) { + switch s := v.(type) { + case string: + return s, true + case net.IP: + if s == nil { + return "", false + } + return s.String(), true + case int64: + return strconv.FormatInt(s, 10), true + default: + return "", false + } +}