fix: eliminate stale verdict poisoning, memory leaks, data races, and per-packet allocations in engine

This commit is contained in:
2026-05-15 02:08:22 +00:00
parent bc25169f41
commit 301c252c43
15 changed files with 222 additions and 163 deletions
+2 -2
View File
@@ -130,8 +130,8 @@ func (s *httpStream) parseResponseLine() utils.LSMAction {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
version := fields[0] version := fields[0]
status, _ := strconv.Atoi(fields[1]) status, err := strconv.Atoi(fields[1])
if !strings.HasPrefix(version, "HTTP/") || status == 0 { if err != nil || !strings.HasPrefix(version, "HTTP/") || status == 0 {
// Invalid version // Invalid version
return utils.LSMActionCancel return utils.LSMActionCancel
} }
+4 -2
View File
@@ -6,6 +6,8 @@ import (
"git.difuse.io/Difuse/Mellaris/analyzer/utils" "git.difuse.io/Difuse/Mellaris/analyzer/utils"
) )
const maxHandshakeLen = 65536
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil) var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)
type TLSAnalyzer struct{} type TLSAnalyzer struct{}
@@ -123,7 +125,7 @@ func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
} }
s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.clientHelloLen < minDataSize { if s.clientHelloLen < minDataSize || s.clientHelloLen > maxHandshakeLen {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
@@ -167,7 +169,7 @@ func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
} }
s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.serverHelloLen < minDataSize { if s.serverHelloLen < minDataSize || s.serverHelloLen > maxHandshakeLen {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
+3 -2
View File
@@ -38,6 +38,7 @@ const (
OpenVPNMinPktLen = 6 OpenVPNMinPktLen = 6
OpenVPNTCPPktDefaultLimit = 256 OpenVPNTCPPktDefaultLimit = 256
OpenVPNUDPPktDefaultLimit = 256 OpenVPNUDPPktDefaultLimit = 256
OpenVPNTCPMaxPktLen = 4096
) )
type OpenVPNAnalyzer struct{} type OpenVPNAnalyzer struct{}
@@ -195,7 +196,7 @@ func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream {
} }
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) { func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
if len(data) == 0 { if len(data) < OpenVPNMinPktLen {
return nil, false return nil, false
} }
var update *analyzer.PropUpdate var update *analyzer.PropUpdate
@@ -338,7 +339,7 @@ func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAc
return nil, utils.LSMActionPause return nil, utils.LSMActionPause
} }
if pktLen < OpenVPNMinPktLen { if pktLen < OpenVPNMinPktLen || pktLen > OpenVPNTCPMaxPktLen {
return nil, utils.LSMActionCancel return nil, utils.LSMActionCancel
} }
+4
View File
@@ -14,6 +14,7 @@ import (
const ( const (
quicInvalidCountThreshold = 16 quicInvalidCountThreshold = 16
quicMaxCryptoDataLen = 256 * 1024 quicMaxCryptoDataLen = 256 * 1024
quicMaxFrameEntries = 100
) )
var ( var (
@@ -158,6 +159,9 @@ func (s *quicStream) mergeFrame(offset int64, data []byte) {
if len(data) == 0 || offset < 0 { if len(data) == 0 || offset < 0 {
return return
} }
if len(s.frames) >= quicMaxFrameEntries {
return
}
if s.frames == nil { if s.frames == nil {
s.frames = make(map[int64][]byte) s.frames = make(map[int64][]byte)
} }
+40 -11
View File
@@ -5,6 +5,7 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -13,10 +14,16 @@ import (
var _ Engine = (*engine)(nil) var _ Engine = (*engine)(nil)
type verdictEntry struct { type verdictEntry struct {
Verdict io.Verdict Verdict io.Verdict
Gen int64 Gen int64
CreatedAt time.Time
} }
const (
verdictTTL = 15 * time.Second
verdictSweepInterval = 15 * time.Second
)
type engine struct { type engine struct {
logger Logger logger Logger
io io.PacketIO io io.PacketIO
@@ -39,7 +46,7 @@ func NewEngine(config Config) (Engine, error) {
} }
overflowPolicy := config.OverflowPolicy overflowPolicy := config.OverflowPolicy
if overflowPolicy == "" { if overflowPolicy == "" {
overflowPolicy = OverflowPolicyAccept overflowPolicy = OverflowPolicyDrop
} }
selectionMode := config.AnalyzerSelectionMode selectionMode := config.AnalyzerSelectionMode
if selectionMode == "" { if selectionMode == "" {
@@ -83,7 +90,6 @@ func NewEngine(config Config) (Engine, error) {
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
e.verdictsGen.Add(1) e.verdictsGen.Add(1)
e.verdicts = sync.Map{}
for _, w := range e.workers { for _, w := range e.workers {
if err := w.UpdateRuleset(r); err != nil { if err := w.UpdateRuleset(r); err != nil {
return err return err
@@ -100,6 +106,7 @@ func (e *engine) Run(ctx context.Context) error {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
go e.drainResults(ioCtx) go e.drainResults(ioCtx)
go e.sweepVerdicts(ioCtx)
errChan := make(chan error, 1) errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
@@ -124,11 +131,13 @@ func (e *engine) Run(ctx context.Context) error {
func (e *engine) dispatch(p io.Packet) bool { func (e *engine) dispatch(p io.Packet) bool {
streamID := p.StreamID() streamID := p.StreamID()
if v, ok := e.verdicts.Load(streamID); ok { if streamID != 0 {
entry := v.(verdictEntry) if v, ok := e.verdicts.Load(streamID); ok {
if entry.Gen == e.verdictsGen.Load() { entry := v.(verdictEntry)
_ = e.io.SetVerdict(p, entry.Verdict, nil) if entry.Gen == e.verdictsGen.Load() {
return true _ = e.io.SetVerdict(p, entry.Verdict, nil)
return true
}
} }
} }
@@ -163,12 +172,32 @@ func (e *engine) dispatch(p io.Packet) bool {
} }
func (e *engine) applyWorkerResult(r workerResult) { func (e *engine) applyWorkerResult(r workerResult) {
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream { if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen}) e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
} }
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket) _ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
} }
// sweepVerdicts periodically evicts expired entries from the per-stream
// verdict cache until ctx is cancelled. Every verdictSweepInterval it walks
// e.verdicts and drops any entry older than verdictTTL, bounding the memory
// held by cached stream verdicts.
func (e *engine) sweepVerdicts(ctx context.Context) {
	t := time.NewTicker(verdictSweepInterval)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			// Entries created strictly before this cutoff have exceeded the TTL.
			cutoff := time.Now().Add(-verdictTTL)
			e.verdicts.Range(func(k, v any) bool {
				if v.(verdictEntry).CreatedAt.Before(cutoff) {
					e.verdicts.Delete(k)
				}
				return true // keep iterating over the whole map
			})
		case <-ctx.Done():
			return
		}
	}
}
func validPacket(data []byte) bool { func validPacket(data []byte) bool {
if len(data) == 0 { if len(data) == 0 {
return false return false
+4 -1
View File
@@ -59,7 +59,10 @@ func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
return return
} }
totalLen := int(uint16(data[2])<<8 | uint16(data[3])) totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 || totalLen > len(data) { if totalLen < int(ihl)*4 {
return
}
if totalLen > len(data) {
totalLen = len(data) totalLen = len(data)
} }
return L3Info{ return L3Info{
+13 -28
View File
@@ -1,7 +1,6 @@
package engine package engine
import ( import (
"net"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type fixedRuleset struct { type fixedRuleset struct {
@@ -60,25 +57,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow}, Ruleset: fixedRuleset{action: ruleset.ActionAllow},
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
udp := &layers.UDP{ payload := []byte("query")
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx := &udpContext{Verdict: udpVerdictAccept} ctx := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx) s := f.New(tuple, payload, ctx)
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err) t.Fatalf("update ruleset: %v", err)
} }
if !s.Accept(udp, false, ctx) { if !s.Accept(false, ctx) {
t.Fatalf("unexpected Accept=false for virgin stream") t.Fatalf("unexpected Accept=false for virgin stream")
} }
s.Feed(udp, false, ctx) s.Feed(false, payload, ctx)
if ctx.Verdict != udpVerdictDropStream { if ctx.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream) t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
} }
@@ -96,21 +87,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow}, Ruleset: fixedRuleset{action: ruleset.ActionAllow},
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
udp := &layers.UDP{ payload := []byte("query")
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx1 := &udpContext{Verdict: udpVerdictAccept} ctx1 := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1) s := f.New(tuple, payload, ctx1)
if !s.Accept(udp, false, ctx1) { if !s.Accept(false, ctx1) {
t.Fatalf("unexpected Accept=false before first feed") t.Fatalf("unexpected Accept=false before first feed")
} }
s.Feed(udp, false, ctx1) s.Feed(false, payload, ctx1)
if ctx1.Verdict != udpVerdictAcceptStream { if ctx1.Verdict != udpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream) t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
} }
@@ -120,16 +105,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
} }
ctx2 := &udpContext{Verdict: udpVerdictAccept} ctx2 := &udpContext{Verdict: udpVerdictAccept}
if !s.Accept(udp, false, ctx2) { if !s.Accept(false, ctx2) {
t.Fatalf("expected Accept=true after ruleset update") t.Fatalf("expected Accept=true after ruleset update")
} }
s.Feed(udp, false, ctx2) s.Feed(false, payload, ctx2)
if ctx2.Verdict != udpVerdictDropStream { if ctx2.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream) t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
} }
ctx3 := &udpContext{Verdict: udpVerdictAccept} ctx3 := &udpContext{Verdict: udpVerdictAccept}
if s.Accept(udp, false, ctx3) { if s.Accept(false, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
} }
if ctx3.Verdict != udpVerdictDropStream { if ctx3.Verdict != udpVerdictDropStream {
+24 -3
View File
@@ -3,6 +3,7 @@ package engine
import ( import (
"net" "net"
"sync" "sync"
"time"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
@@ -13,6 +14,8 @@ import (
const tcpFlowMaxBuffer = 16384 const tcpFlowMaxBuffer = 16384
const tcpFlowIdleTimeout = 10 * time.Minute
type tcpFlowDirection uint8 type tcpFlowDirection uint8
const ( const (
@@ -37,6 +40,7 @@ type tcpFlow struct {
doneEntries []*tcpFlowEntry doneEntries []*tcpFlowEntry
lastVerdict io.Verdict lastVerdict io.Verdict
feedCalled [2]bool feedCalled [2]bool
lastSeen time.Time
} }
type tcpFlowEntry struct { type tcpFlowEntry struct {
@@ -67,16 +71,17 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
expected := f.dirSeq[dir] expected := f.dirSeq[dir]
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected { if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
f.feedCalled[dir] = true f.feedCalled[dir] = true
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer { if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
propUpdated = f.feedAnalyzers(rev) propUpdated = f.feedAnalyzers(rev)
} }
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
} }
} }
f.runMatch(rs, version, rulesetChanged, propUpdated) f.runMatch(rs, version, rulesetChanged, propUpdated)
f.maybeFinalizeVerdict() f.maybeFinalizeVerdict()
f.lastSeen = time.Now()
return f.lastVerdict return f.lastVerdict
} }
@@ -218,7 +223,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
m.logger.TCPStreamNew(m.workerID, info) m.logger.TCPStreamNew(m.workerID, info)
rs, version := m.rulesetSource() var rs ruleset.Ruleset
var version uint64
if m.rulesetSource != nil {
rs, version = m.rulesetSource()
}
var ans []analyzer.TCPAnalyzer var ans []analyzer.TCPAnalyzer
if rs != nil { if rs != nil {
baseAns := rs.Analyzers(info) baseAns := rs.Analyzers(info)
@@ -255,6 +264,7 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
rulesetVersion: version, rulesetVersion: version,
activeEntries: entries, activeEntries: entries,
lastVerdict: io.VerdictAccept, lastVerdict: io.VerdictAccept,
lastSeen: time.Now(),
} }
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1 flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
return flow return flow
@@ -266,6 +276,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
} }
} }
// cleanupIdle removes every TCP flow whose last activity (lastSeen) is older
// than tcpFlowIdleTimeout relative to now, closing the flow's active analyzer
// entries before deleting it. The whole sweep runs under m.mu.
func (m *tcpFlowManager) cleanupIdle(now time.Time) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for id, f := range m.flows {
		if now.Sub(f.lastSeen) <= tcpFlowIdleTimeout {
			continue // still fresh; keep the flow
		}
		f.closeActiveEntries()
		delete(m.flows, id)
	}
}
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
if !entry.HasLimit { if !entry.HasLimit {
update, done = entry.Stream.Feed(rev, true, false, 0, data) update, done = entry.Stream.Feed(rev, true, false, 0, data)
+50 -54
View File
@@ -12,8 +12,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2" lru "github.com/hashicorp/golang-lru/v2"
) )
@@ -49,9 +47,10 @@ type udpStreamFactory struct {
RulesetVersion uint64 RulesetVersion uint64
} }
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
id := f.Node.Generate() id := f.Node.Generate()
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) ipSrc := net.IP(k.AIP[:k.ALen])
ipDst := net.IP(k.BIP[:k.BLen])
info := ruleset.StreamInfo{ info := ruleset.StreamInfo{
ID: id.Int64(), ID: id.Int64(),
Protocol: ruleset.ProtocolUDP, Protocol: ruleset.ProtocolUDP,
@@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...), DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
SrcIP: ipSrc, SrcIP: ipSrc,
DstIP: ipDst, DstIP: ipDst,
SrcPort: uint16(udp.SrcPort), SrcPort: k.APort,
DstPort: uint16(udp.DstPort), DstPort: k.BPort,
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.UDPStreamNew(f.WorkerID, info) f.Logger.UDPStreamNew(f.WorkerID, info)
@@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
if rs != nil { if rs != nil {
baseAns := rs.Analyzers(info) baseAns := rs.Analyzers(info)
if f.Selector != nil { if f.Selector != nil {
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload) baseAns = f.Selector.SelectUDP(baseAns, payload)
} }
ans = analyzersToUDPAnalyzers(baseAns) ans = analyzersToUDPAnalyzers(baseAns)
} }
// Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans)) entries := make([]*udpStreamEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
entries = append(entries, &udpStreamEntry{ entries = append(entries, &udpStreamEntry{
@@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
Stream: a.NewUDP(analyzer.UDPInfo{ Stream: a.NewUDP(analyzer.UDPInfo{
SrcIP: ipSrc, SrcIP: ipSrc,
DstIP: ipDst, DstIP: ipDst,
SrcPort: uint16(udp.SrcPort), SrcPort: k.APort,
DstPort: uint16(udp.DstPort), DstPort: k.BPort,
}, &analyzerLogger{ }, &analyzerLogger{
StreamID: id.Int64(), StreamID: id.Int64(),
Name: a.Name(), Name: a.Name(),
@@ -125,9 +123,14 @@ type udpStreamManager struct {
} }
type udpStreamValue struct { type udpStreamValue struct {
Stream *udpStream Stream *udpStream
IPFlow gopacket.Flow Tuple udpTupleKey
UDPFlow gopacket.Flow }
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
fwd := v.Tuple == k
rev = v.Tuple == reverseTuple(k)
return fwd || rev, rev
} }
type udpTupleKey struct { type udpTupleKey struct {
@@ -139,12 +142,6 @@ type udpTupleKey struct {
BPort uint16 BPort uint16
} }
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
return fwd || rev, rev
}
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
m := &udpStreamManager{ m := &udpStreamManager{
factory: factory, factory: factory,
@@ -153,6 +150,9 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
stats: stats, stats: stats,
} }
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) { ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
if v != nil && v.Stream != nil {
v.Stream.Close()
}
m.removeTupleMappingLocked(k) m.removeTupleMappingLocked(k)
}) })
if err != nil { if err != nil {
@@ -162,16 +162,12 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
return m, nil return m, nil
} }
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) { func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
rev := false
value, ok := m.streams.Get(streamID) value, ok := m.streams.Get(streamID)
tuple := canonicalUDPTupleKey(ipFlow, udp)
if !ok { if !ok {
if m.stats != nil { if m.stats != nil {
m.stats.UDPTupleLookups.Add(1) m.stats.UDPTupleLookups.Add(1)
} }
// Conntrack IDs can change during early flow lifetime on some systems.
// Rebind by canonical 5-tuple in O(1).
matchedKey, found := m.tupleIndex[tuple] matchedKey, found := m.tupleIndex[tuple]
var matchedValue *udpStreamValue var matchedValue *udpStreamValue
var matchedRev bool var matchedRev bool
@@ -188,7 +184,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
} }
} }
if found { if found {
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow()) _, matchedRev = matchedValue.Match(tuple)
value = matchedValue value = matchedValue
rev = matchedRev rev = matchedRev
if matchedKey != streamID { if matchedKey != streamID {
@@ -197,32 +193,27 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// New stream
value = &udpStreamValue{ value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), Stream: m.factory.New(tuple, payload, uc),
IPFlow: ipFlow, Tuple: tuple,
UDPFlow: udp.TransportFlow(),
} }
m.streams.Add(streamID, value) m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// Stream ID exists, but is it really the same stream? ok, rev = value.Match(tuple)
ok, rev = value.Match(ipFlow, udp.TransportFlow())
if !ok { if !ok {
// It's not - close the old stream & replace it with a new one
value.Stream.Close() value.Stream.Close()
value = &udpStreamValue{ value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), Stream: m.factory.New(tuple, payload, uc),
IPFlow: ipFlow, Tuple: tuple,
UDPFlow: udp.TransportFlow(),
} }
m.streams.Add(streamID, value) m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} }
if value.Stream.Accept(udp, rev, uc) { if value.Stream.Accept(rev, uc) {
value.Stream.Feed(udp, rev, uc) value.Stream.Feed(rev, payload, uc)
} }
} }
@@ -242,25 +233,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
} }
} }
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey { func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
srcIP := ipFlow.Src().Raw() srcRaw := []byte(srcIP)
dstIP := ipFlow.Dst().Raw() dstRaw := []byte(dstIP)
srcPort := uint16(udp.SrcPort)
dstPort := uint16(udp.DstPort)
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 { if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
srcIP, dstIP = dstIP, srcIP srcRaw, dstRaw = dstRaw, srcRaw
srcPort, dstPort = dstPort, srcPort srcPort, dstPort = dstPort, srcPort
} }
var key udpTupleKey var key udpTupleKey
key.ALen = uint8(copy(key.AIP[:], srcIP)) key.ALen = uint8(copy(key.AIP[:], srcRaw))
key.BLen = uint8(copy(key.BIP[:], dstIP)) key.BLen = uint8(copy(key.BIP[:], dstRaw))
key.APort = srcPort key.APort = srcPort
key.BPort = dstPort key.BPort = dstPort
return key return key
} }
// reverseTuple returns k with its two endpoints swapped: the A-side IP,
// IP length, and port become the B side and vice versa. Used to match a
// packet against a stream keyed in the opposite direction.
func reverseTuple(k udpTupleKey) udpTupleKey {
	return udpTupleKey{
		AIP:   k.BIP,
		BIP:   k.AIP,
		ALen:  k.BLen,
		BLen:  k.ALen,
		APort: k.BPort,
		BPort: k.APort,
	}
}
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int { func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
if len(aIP) != len(bIP) { if len(aIP) != len(bIP) {
if len(aIP) < len(bIP) { if len(aIP) < len(bIP) {
@@ -298,11 +298,8 @@ type udpStreamEntry struct {
Quota int Quota int
} }
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched.
return true return true
} else { } else {
uc.Verdict = s.lastVerdict uc.Verdict = s.lastVerdict
@@ -310,12 +307,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
} }
} }
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
updated := false updated := false
for i := len(s.activeEntries) - 1; i >= 0; i-- { for i := len(s.activeEntries) - 1; i >= 0; i-- {
// Important: reverse order so we can remove entries
entry := s.activeEntries[i] entry := s.activeEntries[i]
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload) update, closeUpdate, done := s.feedEntry(entry, rev, payload)
up1 := processPropUpdate(s.info.Props, entry.Name, update) up1 := processPropUpdate(s.info.Props, entry.Name, update)
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate) up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
updated = updated || up1 || up2 updated = updated || up1 || up2
@@ -345,7 +341,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
action = ruleset.ActionMaybe action = ruleset.ActionMaybe
} else { } else {
var err error var err error
uc.Packet, err = udpMI.Process(udp.Payload) uc.Packet, err = udpMI.Process(payload)
if err != nil { if err != nil {
// Modifier error, fallback to maybe // Modifier error, fallback to maybe
s.logger.ModifyError(s.info, err) s.logger.ModifyError(s.info, err)
+26 -30
View File
@@ -1,20 +1,16 @@
package engine package engine
import ( import (
"net"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type legacyUDPStreamValue struct { type legacyUDPStreamValue struct {
IPFlow gopacket.Flow Tuple udpTupleKey
UDPFlow gopacket.Flow
} }
type emptyRuleset struct{} type emptyRuleset struct{}
@@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
} }
const flowCount = 20000 const flowCount = 20000
flows := make([]gopacket.Flow, flowCount) tuples := make([]udpTupleKey, flowCount)
udps := make([]*layers.UDP, flowCount) payloads := make([][]byte, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
a := byte(i >> 8) a := byte(i >> 8)
c := byte(i) c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) var t udpTupleKey
udps[i] = &layers.UDP{ t.AIP = [16]byte{10, a, 0, c}
SrcPort: layers.UDPPort(1024 + i%20000), t.ALen = 4
DstPort: layers.UDPPort(20000 + (i*7)%20000), t.BIP = [16]byte{172, 16, a, c}
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, t.BLen = 4
} t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
payloads[i] = []byte{0x01, 0x00, 0x00, 0x00}
} }
ctx := &udpContext{Verdict: udpVerdictAccept} ctx := &udpContext{Verdict: udpVerdictAccept}
@@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
} }
ctx.Verdict = udpVerdictAccept ctx.Verdict = udpVerdictAccept
ctx.Packet = nil ctx.Packet = nil
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx) mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx)
} }
} }
@@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) { func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
const flowCount = 5000 const flowCount = 5000
flows := make([]gopacket.Flow, flowCount) tuples := make([]udpTupleKey, flowCount)
udps := make([]*layers.UDP, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
a := byte(i >> 8) a := byte(i >> 8)
c := byte(i) c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) var t udpTupleKey
udps[i] = &layers.UDP{ t.AIP = [16]byte{10, a, 0, c}
SrcPort: layers.UDPPort(1024 + i%20000), t.ALen = 4
DstPort: layers.UDPPort(20000 + (i*7)%20000), t.BIP = [16]byte{172, 16, a, c}
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, t.BLen = 4
} t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
} }
streams := make(map[uint32]*legacyUDPStreamValue, flowCount) streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
keys := make([]uint32, 0, flowCount) keys := make([]uint32, 0, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
streamID := uint32(i + 1) streamID := uint32(i + 1)
streams[streamID] = &legacyUDPStreamValue{ streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]}
IPFlow: flows[i],
UDPFlow: udps[i].TransportFlow(),
}
keys = append(keys, streamID) keys = append(keys, streamID)
} }
@@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
if _, ok := streams[streamID]; ok { if _, ok := streams[streamID]; ok {
continue continue
} }
ipFlow := flows[idx] tuple := tuples[idx]
udpFlow := udps[idx].TransportFlow() revTuple := reverseTuple(tuple)
for _, k := range keys { for _, k := range keys {
v, ok := streams[k] v, ok := streams[k]
if !ok || v == nil { if !ok || v == nil {
continue continue
} }
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) || if v.Tuple == tuple || v.Tuple == revTuple {
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
delete(streams, k) delete(streams, k)
streams[streamID] = v streams[streamID] = v
break break
+4 -7
View File
@@ -1,7 +1,6 @@
package engine package engine
import ( import (
"net"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type countingRuleset struct { type countingRuleset struct {
@@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
t.Fatalf("new manager: %v", err) t.Fatalf("new manager: %v", err)
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443}
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}} payload := []byte{0x01, 0x00, 0x00, 0x00}
ctx1 := &udpContext{Verdict: udpVerdictAccept} ctx1 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(100, ipFlow, udp, ctx1) mgr.MatchWithContext(100, tuple, false, payload, ctx1)
if got := newCalls.Load(); got != 1 { if got := newCalls.Load(); got != 1 {
t.Fatalf("new stream calls=%d want=1", got) t.Fatalf("new stream calls=%d want=1", got)
} }
ctx2 := &udpContext{Verdict: udpVerdictAccept} ctx2 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(200, ipFlow, udp, ctx2) mgr.MatchWithContext(200, tuple, false, payload, ctx2)
if got := newCalls.Load(); got != 1 { if got := newCalls.Load(); got != 1 {
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got) t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
} }
+11 -16
View File
@@ -3,6 +3,7 @@ package engine
import ( import (
"context" "context"
"net" "net"
"time"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) {
func (w *worker) Run(ctx context.Context) { func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id) w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id) defer w.logger.WorkerStop(w.id)
tcpSweepTicker := time.NewTicker(1 * time.Minute)
defer tcpSweepTicker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-tcpSweepTicker.C:
w.tcpFlowMgr.cleanupIdle(time.Now())
case wp := <-w.packetChan: case wp := <-w.packetChan:
if wp == nil { if wp == nil {
return return
@@ -202,15 +209,6 @@ func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []by
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := l3.SrcIPAddr() ipSrc := l3.SrcIPAddr()
ipDst := l3.DstIPAddr() ipDst := l3.DstIPAddr()
endpointType := layers.EndpointIPv4
flowSrc := ipSrc.To4()
flowDst := ipDst.To4()
if l3.Version == 6 {
endpointType = layers.EndpointIPv6
flowSrc = ipSrc.To16()
flowDst = ipDst.To16()
}
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
if len(srcMAC) == 0 && w.macResolver != nil { if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc) srcMAC = w.macResolver.Resolve(ipSrc)
@@ -221,12 +219,9 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
SrcMAC: srcMAC, SrcMAC: srcMAC,
DstMAC: dstMAC, DstMAC: dstMAC,
} }
// Temporarily set payload on a UDP layer so existing UDP handling works.
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort)
BaseLayer: layers.BaseLayer{Payload: payload}, w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc)
SrcPort: layers.UDPPort(udp.SrcPort),
DstPort: layers.UDPPort(udp.DstPort),
}, uc)
return io.Verdict(uc.Verdict), uc.Packet return io.Verdict(uc.Verdict), uc.Packet
} }
@@ -253,7 +248,7 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []b
if err != nil { if err != nil {
return io.VerdictAccept, nil return io.VerdictAccept, nil
} }
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes() return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...)
} }
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) { func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
+1 -1
View File
@@ -456,7 +456,7 @@ func ctIDFromCtBytes(ct []byte) uint32 {
return 0 return 0
} }
for _, attr := range ctAttrs { for _, attr := range ctAttrs {
if attr.Type == 12 { // CTA_ID if attr.Type == 12 && len(attr.Data) >= 4 { // CTA_ID
return binary.BigEndian.Uint32(attr.Data) return binary.BigEndian.Uint32(attr.Data)
} }
} }
+6
View File
@@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"sync"
"time" "time"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
@@ -31,6 +32,7 @@ type V2GeoLoader struct {
DownloadFunc func(filename, url string) DownloadFunc func(filename, url string)
DownloadErrFunc func(err error) DownloadErrFunc func(err error)
mu sync.Mutex
geoipMap map[string]*v2geo.GeoIP geoipMap map[string]*v2geo.GeoIP
geositeMap map[string]*v2geo.GeoSite geositeMap map[string]*v2geo.GeoSite
} }
@@ -80,6 +82,8 @@ func (l *V2GeoLoader) download(filename, url string) error {
} }
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geoipMap != nil { if l.geoipMap != nil {
return l.geoipMap, nil return l.geoipMap, nil
} }
@@ -104,6 +108,8 @@ func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
} }
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geositeMap != nil { if l.geositeMap != nil {
return l.geositeMap, nil return l.geositeMap, nil
} }
+30 -6
View File
@@ -519,7 +519,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoIP, InitFunc: geoMatcher.LoadGeoIP,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoIp(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
}, },
@@ -527,7 +532,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite, InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSite(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
}, },
@@ -535,7 +545,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite, InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(*geo.SiteConditionSet)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSiteSet(a, b), nil
}, },
Types: []reflect.Type{ Types: []reflect.Type{
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)), reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
@@ -556,7 +571,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
return nil return nil
}, },
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.IPNet)
if !ok1 || !ok2 {
return false, nil
}
return builtins.MatchCIDR(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)}, Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
}, },
@@ -565,7 +585,6 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
PatchFunc: func(args *[]ast.Node) error { PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode var serverStr *ast.StringNode
if len(*args) > 1 { if len(*args) > 1 {
// Has the optional server argument
var ok bool var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode) serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok { if !ok {
@@ -595,9 +614,14 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
}() }()
} }
a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.Resolver)
if !ok1 || !ok2 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel() defer cancel()
out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) out, err := b.LookupHost(ctx, a)
if err != nil && stats != nil { if err != nil && stats != nil {
stats.LookupErrors.Add(1) stats.LookupErrors.Add(1)
} }