diff --git a/analyzer/tcp/http.go b/analyzer/tcp/http.go index 9b8761b..9823080 100644 --- a/analyzer/tcp/http.go +++ b/analyzer/tcp/http.go @@ -130,8 +130,8 @@ func (s *httpStream) parseResponseLine() utils.LSMAction { return utils.LSMActionCancel } version := fields[0] - status, _ := strconv.Atoi(fields[1]) - if !strings.HasPrefix(version, "HTTP/") || status == 0 { + status, err := strconv.Atoi(fields[1]) + if err != nil || !strings.HasPrefix(version, "HTTP/") || status == 0 { // Invalid version return utils.LSMActionCancel } diff --git a/analyzer/tcp/tls.go b/analyzer/tcp/tls.go index c9c488d..bb05aae 100644 --- a/analyzer/tcp/tls.go +++ b/analyzer/tcp/tls.go @@ -6,6 +6,8 @@ import ( "git.difuse.io/Difuse/Mellaris/analyzer/utils" ) +const maxHandshakeLen = 65536 + var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil) type TLSAnalyzer struct{} @@ -123,7 +125,7 @@ func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction { } s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) - if s.clientHelloLen < minDataSize { + if s.clientHelloLen < minDataSize || s.clientHelloLen > maxHandshakeLen { return utils.LSMActionCancel } @@ -167,7 +169,7 @@ func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction { } s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) - if s.serverHelloLen < minDataSize { + if s.serverHelloLen < minDataSize || s.serverHelloLen > maxHandshakeLen { return utils.LSMActionCancel } diff --git a/analyzer/udp/openvpn.go b/analyzer/udp/openvpn.go index baa1904..a0a13cf 100644 --- a/analyzer/udp/openvpn.go +++ b/analyzer/udp/openvpn.go @@ -38,6 +38,7 @@ const ( OpenVPNMinPktLen = 6 OpenVPNTCPPktDefaultLimit = 256 OpenVPNUDPPktDefaultLimit = 256 + OpenVPNTCPMaxPktLen = 4096 ) type OpenVPNAnalyzer struct{} @@ -195,7 +196,7 @@ func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream { } func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) { - if len(data) == 0 { + if len(data) < OpenVPNMinPktLen { return nil, false } var update *analyzer.PropUpdate @@ -338,7 +339,7 @@ func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAc return nil, utils.LSMActionPause } - if pktLen < OpenVPNMinPktLen { + if pktLen < OpenVPNMinPktLen || pktLen > OpenVPNTCPMaxPktLen { return nil, utils.LSMActionCancel } diff --git a/analyzer/udp/quic.go b/analyzer/udp/quic.go index cd3cdcf..0b9f238 100644 --- a/analyzer/udp/quic.go +++ b/analyzer/udp/quic.go @@ -14,6 +14,7 @@ import ( const ( quicInvalidCountThreshold = 16 quicMaxCryptoDataLen = 256 * 1024 + quicMaxFrameEntries = 100 ) var ( @@ -158,6 +159,9 @@ func (s *quicStream) mergeFrame(offset int64, data []byte) { if len(data) == 0 || offset < 0 { return } + if len(s.frames) >= quicMaxFrameEntries { + return + } if s.frames == nil { s.frames = make(map[int64][]byte) } diff --git a/engine/engine.go b/engine/engine.go index 9934090..a221187 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -5,6 +5,7 @@ import ( "runtime" "sync" "sync/atomic" + "time" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" @@ -13,10 +14,16 @@ import ( var _ Engine = (*engine)(nil) type verdictEntry struct { - Verdict io.Verdict - Gen int64 + Verdict io.Verdict + Gen int64 + CreatedAt time.Time } +const ( + verdictTTL = 15 * time.Second + verdictSweepInterval = 15 * time.Second +) + type engine struct { logger Logger io io.PacketIO @@ -39,7 +46,7 @@ func NewEngine(config Config) (Engine, error) { } overflowPolicy := config.OverflowPolicy if overflowPolicy == "" { - overflowPolicy = OverflowPolicyAccept + overflowPolicy = OverflowPolicyDrop } selectionMode := config.AnalyzerSelectionMode if selectionMode == "" { @@ -83,7 +90,6 @@ func NewEngine(config Config) (Engine, error) { func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { e.verdictsGen.Add(1) - e.verdicts = sync.Map{} for _, w := range e.workers { if err := w.UpdateRuleset(r); err != nil { return err @@ -100,6 +106,7 @@ func (e *engine) Run(ctx context.Context) error { go w.Run(ioCtx) } go e.drainResults(ioCtx) + go e.sweepVerdicts(ioCtx) errChan := make(chan error, 1) err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { @@ -124,11 +131,13 @@ func (e *engine) Run(ctx context.Context) error { func (e *engine) dispatch(p io.Packet) bool { streamID := p.StreamID() - if v, ok := e.verdicts.Load(streamID); ok { - entry := v.(verdictEntry) - if entry.Gen == e.verdictsGen.Load() { - _ = e.io.SetVerdict(p, entry.Verdict, nil) - return true + if streamID != 0 { + if v, ok := e.verdicts.Load(streamID); ok { + entry := v.(verdictEntry) + if entry.Gen == e.verdictsGen.Load() { + _ = e.io.SetVerdict(p, entry.Verdict, nil) + return true + } } } @@ -163,12 +172,32 @@ func (e *engine) dispatch(p io.Packet) bool { } func (e *engine) applyWorkerResult(r workerResult) { - if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream { - e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen}) + if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) { + e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()}) } _ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket) } +func (e *engine) sweepVerdicts(ctx context.Context) { + ticker := time.NewTicker(verdictSweepInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + e.verdicts.Range(func(key, value interface{}) bool { + entry := value.(verdictEntry) + if now.Sub(entry.CreatedAt) > verdictTTL { + e.verdicts.Delete(key) + } + return true + }) + } + } +} + func validPacket(data []byte) bool { if len(data) == 0 { return false diff --git a/engine/packet.go b/engine/packet.go index 6b47e4e..cdd7543 100644 --- a/engine/packet.go +++ b/engine/packet.go @@ -59,7 +59,10 @@ func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) { return } totalLen := int(uint16(data[2])<<8 | uint16(data[3])) - if totalLen < int(ihl)*4 || totalLen > len(data) { + if totalLen < int(ihl)*4 { + return + } + if totalLen > len(data) { totalLen = len(data) } return L3Info{ diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go index cda881a..147457c 100644 --- a/engine/reload_rules_test.go +++ b/engine/reload_rules_test.go @@ -1,7 +1,6 @@ package engine import ( - "net" "testing" "git.difuse.io/Difuse/Mellaris/analyzer" @@ -9,8 +8,6 @@ import ( "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" ) type fixedRuleset struct { @@ -60,25 +57,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) { Ruleset: fixedRuleset{action: ruleset.ActionAllow}, } - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) - udp := &layers.UDP{ - SrcPort: 12345, - DstPort: 53, - BaseLayer: layers.BaseLayer{ - Payload: []byte("query"), - }, - } + tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53} + payload := []byte("query") ctx := &udpContext{Verdict: udpVerdictAccept} - s := f.New(ipFlow, udp.TransportFlow(), udp, ctx) + s := f.New(tuple, payload, ctx) if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { t.Fatalf("update ruleset: %v", err) } - if !s.Accept(udp, false, ctx) { + if !s.Accept(false, ctx) { t.Fatalf("unexpected Accept=false for virgin stream") } - s.Feed(udp, false, ctx) + s.Feed(false, payload, ctx) if ctx.Verdict != udpVerdictDropStream { t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream) } @@ -96,21 +87,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { Ruleset: fixedRuleset{action: ruleset.ActionAllow}, } - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) - udp := &layers.UDP{ - SrcPort: 12345, - DstPort: 53, - BaseLayer: layers.BaseLayer{ - Payload: []byte("query"), - }, - } + tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53} + payload := []byte("query") ctx1 := &udpContext{Verdict: udpVerdictAccept} - s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1) - if !s.Accept(udp, false, ctx1) { + s := f.New(tuple, payload, ctx1) + if !s.Accept(false, ctx1) { t.Fatalf("unexpected Accept=false before first feed") } - s.Feed(udp, false, ctx1) + s.Feed(false, payload, ctx1) if ctx1.Verdict != udpVerdictAcceptStream { t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream) } @@ -120,16 +105,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { } ctx2 := &udpContext{Verdict: udpVerdictAccept} - if !s.Accept(udp, false, ctx2) { + if !s.Accept(false, ctx2) { t.Fatalf("expected Accept=true after ruleset update") } - s.Feed(udp, false, ctx2) + s.Feed(false, payload, ctx2) if ctx2.Verdict != udpVerdictDropStream { t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream) } ctx3 := &udpContext{Verdict: udpVerdictAccept} - if s.Accept(udp, false, ctx3) { + if s.Accept(false, ctx3) { t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") } if ctx3.Verdict != udpVerdictDropStream { diff --git a/engine/tcp_flow.go b/engine/tcp_flow.go index 82726a4..2d8e944 100644 --- a/engine/tcp_flow.go +++ b/engine/tcp_flow.go @@ -3,6 +3,7 @@ package engine import ( "net" "sync" + "time" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/io" @@ -13,6 +14,8 @@ import ( const tcpFlowMaxBuffer = 16384 +const tcpFlowIdleTimeout = 10 * time.Minute + type tcpFlowDirection uint8 const ( @@ -37,6 +40,7 @@ type tcpFlow struct { doneEntries []*tcpFlowEntry lastVerdict io.Verdict feedCalled [2]bool + lastSeen time.Time } type tcpFlowEntry struct { @@ -67,16 +71,17 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { expected := f.dirSeq[dir] if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected { f.feedCalled[dir] = true - f.dirBuf[dir] = append(f.dirBuf[dir], payload...) - f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer { + f.dirBuf[dir] = append(f.dirBuf[dir], payload...) propUpdated = f.feedAnalyzers(rev) } + f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) } } f.runMatch(rs, version, rulesetChanged, propUpdated) f.maybeFinalizeVerdict() + f.lastSeen = time.Now() return f.lastVerdict } @@ -218,7 +223,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay Props: make(analyzer.CombinedPropMap), } m.logger.TCPStreamNew(m.workerID, info) - rs, version := m.rulesetSource() + var rs ruleset.Ruleset + var version uint64 + if m.rulesetSource != nil { + rs, version = m.rulesetSource() + } var ans []analyzer.TCPAnalyzer if rs != nil { baseAns := rs.Analyzers(info) @@ -255,6 +264,7 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay rulesetVersion: version, activeEntries: entries, lastVerdict: io.VerdictAccept, + lastSeen: time.Now(), } flow.dirSeq[tcpDirC2S] = tcp.Seq + 1 return flow @@ -266,6 +276,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) { } } +func (m *tcpFlowManager) cleanupIdle(now time.Time) { + m.mu.Lock() + defer m.mu.Unlock() + for id, flow := range m.flows { + if now.Sub(flow.lastSeen) > tcpFlowIdleTimeout { + flow.closeActiveEntries() + delete(m.flows, id) + } + } +} + func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { if !entry.HasLimit { update, done = entry.Stream.Feed(rev, true, false, 0, data) diff --git a/engine/udp.go b/engine/udp.go index 5d43177..e732c23 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -12,8 +12,6 @@ import ( "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" lru "github.com/hashicorp/golang-lru/v2" ) @@ -49,9 +47,10 @@ type udpStreamFactory struct { RulesetVersion uint64 } -func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { +func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream { id := f.Node.Generate() - ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) + ipSrc := net.IP(k.AIP[:k.ALen]) + ipDst := net.IP(k.BIP[:k.BLen]) info := ruleset.StreamInfo{ ID: id.Int64(), Protocol: ruleset.ProtocolUDP, @@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...), SrcIP: ipSrc, DstIP: ipDst, - SrcPort: uint16(udp.SrcPort), - DstPort: uint16(udp.DstPort), + SrcPort: k.APort, + DstPort: k.BPort, Props: make(analyzer.CombinedPropMap), } f.Logger.UDPStreamNew(f.WorkerID, info) @@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u if rs != nil { baseAns := rs.Analyzers(info) if f.Selector != nil { - baseAns = f.Selector.SelectUDP(baseAns, udp.Payload) + baseAns = f.Selector.SelectUDP(baseAns, payload) } ans = analyzersToUDPAnalyzers(baseAns) } - // Create entries for each analyzer entries := make([]*udpStreamEntry, 0, len(ans)) for _, a := range ans { entries = append(entries, &udpStreamEntry{ @@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u Stream: a.NewUDP(analyzer.UDPInfo{ SrcIP: ipSrc, DstIP: ipDst, - SrcPort: uint16(udp.SrcPort), - DstPort: uint16(udp.DstPort), + SrcPort: k.APort, + DstPort: k.BPort, }, &analyzerLogger{ StreamID: id.Int64(), Name: a.Name(), @@ -125,9 +123,14 @@ type udpStreamManager struct { } type udpStreamValue struct { - Stream *udpStream - IPFlow gopacket.Flow - UDPFlow gopacket.Flow + Stream *udpStream + Tuple udpTupleKey +} + +func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) { + fwd := v.Tuple == k + rev = v.Tuple == reverseTuple(k) + return fwd || rev, rev } type udpTupleKey struct { @@ -139,12 +142,6 @@ type udpTupleKey struct { BPort uint16 } -func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) { - fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow - rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse() - return fwd || rev, rev -} - func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { m := &udpStreamManager{ factory: factory, @@ -153,6 +150,9 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats stats: stats, } ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) { + if v != nil && v.Stream != nil { + v.Stream.Close() + } m.removeTupleMappingLocked(k) }) if err != nil { @@ -162,16 +162,12 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats return m, nil } -func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) { - rev := false +func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) { value, ok := m.streams.Get(streamID) - tuple := canonicalUDPTupleKey(ipFlow, udp) if !ok { if m.stats != nil { m.stats.UDPTupleLookups.Add(1) } - // Conntrack IDs can change during early flow lifetime on some systems. - // Rebind by canonical 5-tuple in O(1). matchedKey, found := m.tupleIndex[tuple] var matchedValue *udpStreamValue var matchedRev bool @@ -188,7 +184,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo } } if found { - _, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow()) + _, matchedRev = matchedValue.Match(tuple) value = matchedValue rev = matchedRev if matchedKey != streamID { @@ -197,32 +193,27 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo m.bindTupleLocked(streamID, tuple) } } else { - // New stream value = &udpStreamValue{ - Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), - IPFlow: ipFlow, - UDPFlow: udp.TransportFlow(), + Stream: m.factory.New(tuple, payload, uc), + Tuple: tuple, } m.streams.Add(streamID, value) m.bindTupleLocked(streamID, tuple) } } else { - // Stream ID exists, but is it really the same stream? - ok, rev = value.Match(ipFlow, udp.TransportFlow()) + ok, rev = value.Match(tuple) if !ok { - // It's not - close the old stream & replace it with a new one value.Stream.Close() value = &udpStreamValue{ - Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), - IPFlow: ipFlow, - UDPFlow: udp.TransportFlow(), + Stream: m.factory.New(tuple, payload, uc), + Tuple: tuple, } m.streams.Add(streamID, value) m.bindTupleLocked(streamID, tuple) } } - if value.Stream.Accept(udp, rev, uc) { - value.Stream.Feed(udp, rev, uc) + if value.Stream.Accept(rev, uc) { + value.Stream.Feed(rev, payload, uc) } } @@ -242,25 +233,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) { } } -func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey { - srcIP := ipFlow.Src().Raw() - dstIP := ipFlow.Dst().Raw() - srcPort := uint16(udp.SrcPort) - dstPort := uint16(udp.DstPort) +func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey { + srcRaw := []byte(srcIP) + dstRaw := []byte(dstIP) - if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 { - srcIP, dstIP = dstIP, srcIP + if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 { + srcRaw, dstRaw = dstRaw, srcRaw srcPort, dstPort = dstPort, srcPort } var key udpTupleKey - key.ALen = uint8(copy(key.AIP[:], srcIP)) - key.BLen = uint8(copy(key.BIP[:], dstIP)) + key.ALen = uint8(copy(key.AIP[:], srcRaw)) + key.BLen = uint8(copy(key.BIP[:], dstRaw)) key.APort = srcPort key.BPort = dstPort return key } +func reverseTuple(k udpTupleKey) udpTupleKey { + var r udpTupleKey + r.ALen = k.BLen + r.BLen = k.ALen + r.AIP = k.BIP + r.BIP = k.AIP + r.APort = k.BPort + r.BPort = k.APort + return r +} + func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int { if len(aIP) != len(bIP) { if len(aIP) < len(bIP) { @@ -298,11 +298,8 @@ type udpStreamEntry struct { Quota int } -func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { +func (s *udpStream) Accept(rev bool, uc *udpContext) bool { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { - // Make sure every stream matches against the ruleset at least once, - // even if there are no activeEntries, as the ruleset may have built-in - // properties that need to be matched. return true } else { uc.Verdict = s.lastVerdict @@ -310,12 +307,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { } } -func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { +func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) { updated := false for i := len(s.activeEntries) - 1; i >= 0; i-- { - // Important: reverse order so we can remove entries entry := s.activeEntries[i] - update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload) + update, closeUpdate, done := s.feedEntry(entry, rev, payload) up1 := processPropUpdate(s.info.Props, entry.Name, update) up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate) updated = updated || up1 || up2 @@ -345,7 +341,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { action = ruleset.ActionMaybe } else { var err error - uc.Packet, err = udpMI.Process(udp.Payload) + uc.Packet, err = udpMI.Process(payload) if err != nil { // Modifier error, fallback to maybe s.logger.ModifyError(s.info, err) diff --git a/engine/udp_manager_bench_test.go b/engine/udp_manager_bench_test.go index 49a8fb9..9498ccf 100644 --- a/engine/udp_manager_bench_test.go +++ b/engine/udp_manager_bench_test.go @@ -1,20 +1,16 @@ package engine import ( - "net" "testing" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" ) type legacyUDPStreamValue struct { - IPFlow gopacket.Flow - UDPFlow gopacket.Flow + Tuple udpTupleKey } type emptyRuleset struct{} @@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) { } const flowCount = 20000 - flows := make([]gopacket.Flow, flowCount) - udps := make([]*layers.UDP, flowCount) + tuples := make([]udpTupleKey, flowCount) + payloads := make([][]byte, flowCount) for i := 0; i < flowCount; i++ { a := byte(i >> 8) c := byte(i) - flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) - udps[i] = &layers.UDP{ - SrcPort: layers.UDPPort(1024 + i%20000), - DstPort: layers.UDPPort(20000 + (i*7)%20000), - BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, - } + var t udpTupleKey + t.AIP = [16]byte{10, a, 0, c} + t.ALen = 4 + t.BIP = [16]byte{172, 16, a, c} + t.BLen = 4 + t.APort = 1024 + uint16(i%20000) + t.BPort = 20000 + uint16((i*7)%20000) + tuples[i] = t + payloads[i] = []byte{0x01, 0x00, 0x00, 0x00} } ctx := &udpContext{Verdict: udpVerdictAccept} @@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) { } ctx.Verdict = udpVerdictAccept ctx.Packet = nil - mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx) + mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx) } } @@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) { func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) { const flowCount = 5000 - flows := make([]gopacket.Flow, flowCount) - udps := make([]*layers.UDP, flowCount) + tuples := make([]udpTupleKey, flowCount) for i := 0; i < flowCount; i++ { a := byte(i >> 8) c := byte(i) - flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) - udps[i] = &layers.UDP{ - SrcPort: layers.UDPPort(1024 + i%20000), - DstPort: layers.UDPPort(20000 + (i*7)%20000), - BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, - } + var t udpTupleKey + t.AIP = [16]byte{10, a, 0, c} + t.ALen = 4 + t.BIP = [16]byte{172, 16, a, c} + t.BLen = 4 + t.APort = 1024 + uint16(i%20000) + t.BPort = 20000 + uint16((i*7)%20000) + tuples[i] = t } streams := make(map[uint32]*legacyUDPStreamValue, flowCount) keys := make([]uint32, 0, flowCount) for i := 0; i < flowCount; i++ { streamID := uint32(i + 1) - streams[streamID] = &legacyUDPStreamValue{ - IPFlow: flows[i], - UDPFlow: udps[i].TransportFlow(), - } + streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]} keys = append(keys, streamID) } @@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) { if _, ok := streams[streamID]; ok { continue } - ipFlow := flows[idx] - udpFlow := udps[idx].TransportFlow() + tuple := tuples[idx] + revTuple := reverseTuple(tuple) for _, k := range keys { v, ok := streams[k] if !ok || v == nil { continue } - if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) || - (v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) { + if v.Tuple == tuple || v.Tuple == revTuple { delete(streams, k) streams[streamID] = v break diff --git a/engine/udp_manager_tuple_test.go b/engine/udp_manager_tuple_test.go index 3d9e5f2..ea71c44 100644 --- a/engine/udp_manager_tuple_test.go +++ b/engine/udp_manager_tuple_test.go @@ -1,7 +1,6 @@ package engine import ( - "net" "sync/atomic" "testing" @@ -9,8 +8,6 @@ import ( "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" ) type countingRuleset struct { @@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) { t.Fatalf("new manager: %v", err) } - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) - udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}} + tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443} + payload := []byte{0x01, 0x00, 0x00, 0x00} ctx1 := &udpContext{Verdict: udpVerdictAccept} - mgr.MatchWithContext(100, ipFlow, udp, ctx1) + mgr.MatchWithContext(100, tuple, false, payload, ctx1) if got := newCalls.Load(); got != 1 { t.Fatalf("new stream calls=%d want=1", got) } ctx2 := &udpContext{Verdict: udpVerdictAccept} - mgr.MatchWithContext(200, ipFlow, udp, ctx2) + mgr.MatchWithContext(200, tuple, false, payload, ctx2) if got := newCalls.Load(); got != 1 { t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got) } diff --git a/engine/worker.go b/engine/worker.go index 080ceca..4faf21a 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -3,6 +3,7 @@ package engine import ( "context" "net" + "time" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" @@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) { func (w *worker) Run(ctx context.Context) { w.logger.WorkerStart(w.id) defer w.logger.WorkerStop(w.id) + + tcpSweepTicker := time.NewTicker(1 * time.Minute) + defer tcpSweepTicker.Stop() + for { select { case <-ctx.Done(): return + case <-tcpSweepTicker.C: + w.tcpFlowMgr.cleanupIdle(time.Now()) case wp := <-w.packetChan: if wp == nil { return @@ -202,15 +209,6 @@ func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []by func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { ipSrc := l3.SrcIPAddr() ipDst := l3.DstIPAddr() - endpointType := layers.EndpointIPv4 - flowSrc := ipSrc.To4() - flowDst := ipDst.To4() - if l3.Version == 6 { - endpointType = layers.EndpointIPv6 - flowSrc = ipSrc.To16() - flowDst = ipDst.To16() - } - ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst) if len(srcMAC) == 0 && w.macResolver != nil { srcMAC = w.macResolver.Resolve(ipSrc) @@ -221,12 +219,9 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by SrcMAC: srcMAC, DstMAC: dstMAC, } - // Temporarily set payload on a UDP layer so existing UDP handling works. - w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ - BaseLayer: layers.BaseLayer{Payload: payload}, - SrcPort: layers.UDPPort(udp.SrcPort), - DstPort: layers.UDPPort(udp.DstPort), - }, uc) + + tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort) + w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc) return io.Verdict(uc.Verdict), uc.Packet } @@ -253,7 +248,7 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []b if err != nil { return io.VerdictAccept, nil } - return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes() + return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...) } func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) { diff --git a/io/nfqueue.go b/io/nfqueue.go index 683dc0e..07301c4 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -456,7 +456,7 @@ func ctIDFromCtBytes(ct []byte) uint32 { return 0 } for _, attr := range ctAttrs { - if attr.Type == 12 { // CTA_ID + if attr.Type == 12 && len(attr.Data) >= 4 { // CTA_ID return binary.BigEndian.Uint32(attr.Data) } } diff --git a/ruleset/builtins/geo/geo_loader.go b/ruleset/builtins/geo/geo_loader.go index d99853b..30dfdb0 100644 --- a/ruleset/builtins/geo/geo_loader.go +++ b/ruleset/builtins/geo/geo_loader.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "os" + "sync" "time" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" @@ -31,6 +32,7 @@ type V2GeoLoader struct { DownloadFunc func(filename, url string) DownloadErrFunc func(err error) + mu sync.Mutex geoipMap map[string]*v2geo.GeoIP geositeMap map[string]*v2geo.GeoSite } @@ -80,6 +82,8 @@ func (l *V2GeoLoader) download(filename, url string) error { } func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { + l.mu.Lock() + defer l.mu.Unlock() if l.geoipMap != nil { return l.geoipMap, nil } @@ -104,6 +108,8 @@ func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { } func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { + l.mu.Lock() + defer l.mu.Unlock() if l.geositeMap != nil { return l.geositeMap, nil } diff --git a/ruleset/expr.go b/ruleset/expr.go index 182a41e..7838db9 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -519,7 +519,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* InitFunc: geoMatcher.LoadGeoIP, PatchFunc: nil, Func: func(params ...any) (any, error) { - return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil + a, ok1 := params[0].(string) + b, ok2 := params[1].(string) + if !ok1 || !ok2 { + return false, nil + } + return geoMatcher.MatchGeoIp(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, }, @@ -527,7 +532,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* InitFunc: geoMatcher.LoadGeoSite, PatchFunc: nil, Func: func(params ...any) (any, error) { - return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil + a, ok1 := params[0].(string) + b, ok2 := params[1].(string) + if !ok1 || !ok2 { + return false, nil + } + return geoMatcher.MatchGeoSite(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, }, @@ -535,7 +545,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* InitFunc: geoMatcher.LoadGeoSite, PatchFunc: nil, Func: func(params ...any) (any, error) { - return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil + a, ok1 := params[0].(string) + b, ok2 := params[1].(*geo.SiteConditionSet) + if !ok1 || !ok2 { + return false, nil + } + return geoMatcher.MatchGeoSiteSet(a, b), nil }, Types: []reflect.Type{ reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)), @@ -556,7 +571,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* return nil }, Func: func(params ...any) (any, error) { - return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil + a, ok1 := params[0].(string) + b, ok2 := params[1].(*net.IPNet) + if !ok1 || !ok2 { + return false, nil + } + return builtins.MatchCIDR(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)}, }, @@ -565,7 +585,6 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* PatchFunc: func(args *[]ast.Node) error { var serverStr *ast.StringNode if len(*args) > 1 { - // Has the optional server argument var ok bool serverStr, ok = (*args)[1].(*ast.StringNode) if !ok { @@ -595,9 +614,14 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]* stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) }() } + a, ok1 := params[0].(string) + b, ok2 := params[1].(*net.Resolver) + if !ok1 || !ok2 { + return nil, nil + } ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) defer cancel() - out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) + out, err := b.LookupHost(ctx, a) if err != nil && stats != nil { stats.LookupErrors.Add(1) }