Improves flow handling and adds runtime stats APIs

Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
This commit is contained in:
2026-05-13 06:10:38 +05:30
parent 3f895adb43
commit 7a3f6e945d
23 changed files with 1440 additions and 152 deletions
+23 -2
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"runtime"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/engine" "git.difuse.io/Difuse/Mellaris/engine"
@@ -17,6 +18,7 @@ type App struct {
engine engine.Engine engine engine.Engine
io gfwio.PacketIO io gfwio.PacketIO
rulesetConfig *ruleset.BuiltinConfig rulesetConfig *ruleset.BuiltinConfig
ruleset ruleset.Ruleset
analyzers []analyzer.Analyzer analyzers []analyzer.Analyzer
modifiers []modifier.Modifier modifiers []modifier.Modifier
rulesFile string rulesFile string
@@ -42,6 +44,11 @@ func New(cfg Config, opts Options) (*App, error) {
packetIO := cfg.IO.PacketIO packetIO := cfg.IO.PacketIO
ownsIO := false ownsIO := false
workerCount := effectiveWorkerCount(cfg.Workers.Count)
numQueues := cfg.IO.NumQueues
if numQueues <= 0 {
numQueues = workerCount
}
if packetIO == nil { if packetIO == nil {
packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{ packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
QueueSize: cfg.IO.QueueSize, QueueSize: cfg.IO.QueueSize,
@@ -49,7 +56,7 @@ func New(cfg Config, opts Options) (*App, error) {
WriteBuffer: cfg.IO.WriteBuffer, WriteBuffer: cfg.IO.WriteBuffer,
Local: cfg.IO.Local, Local: cfg.IO.Local,
RST: cfg.IO.RST, RST: cfg.IO.RST,
NumQueues: cfg.IO.NumQueues, NumQueues: numQueues,
MaxPacketLen: cfg.IO.MaxPacketLen, MaxPacketLen: cfg.IO.MaxPacketLen,
}) })
if err != nil { if err != nil {
@@ -79,11 +86,13 @@ func New(cfg Config, opts Options) (*App, error) {
Logger: engineLogger, Logger: engineLogger,
IO: packetIO, IO: packetIO,
Ruleset: rs, Ruleset: rs,
Workers: cfg.Workers.Count, Workers: workerCount,
WorkerQueueSize: cfg.Workers.QueueSize, WorkerQueueSize: cfg.Workers.QueueSize,
WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal, WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal,
WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn, WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn,
WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams, WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams,
OverflowPolicy: cfg.Workers.OverflowPolicy,
AnalyzerSelectionMode: cfg.Workers.AnalyzerSelectionMode,
} }
eng, err := engine.NewEngine(engCfg) eng, err := engine.NewEngine(engCfg)
if err != nil { if err != nil {
@@ -95,6 +104,7 @@ func New(cfg Config, opts Options) (*App, error) {
engine: eng, engine: eng,
io: packetIO, io: packetIO,
rulesetConfig: rsConfig, rulesetConfig: rsConfig,
ruleset: rs,
analyzers: analyzers, analyzers: analyzers,
modifiers: modifiers, modifiers: modifiers,
rulesFile: rulesFile, rulesFile: rulesFile,
@@ -140,6 +150,17 @@ func (a *App) Engine() engine.Engine {
return a.engine return a.engine
} }
func effectiveWorkerCount(configured int) int {
if configured > 0 {
return configured
}
n := runtime.GOMAXPROCS(0)
if n <= 0 {
return 1
}
return n
}
func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) { func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) {
if opts.RulesFile != "" && len(opts.Rules) > 0 { if opts.RulesFile != "" && len(opts.Rules) > 0 {
return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")} return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")}
+7 -5
View File
@@ -32,11 +32,13 @@ type IOConfig struct {
// WorkersConfig configures engine worker behavior. // WorkersConfig configures engine worker behavior.
type WorkersConfig struct { type WorkersConfig struct {
Count int `mapstructure:"count" yaml:"count"` Count int `mapstructure:"count" yaml:"count"`
QueueSize int `mapstructure:"queueSize" yaml:"queueSize"` QueueSize int `mapstructure:"queueSize" yaml:"queueSize"`
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"` TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"`
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"` TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"`
UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"` UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"`
OverflowPolicy engine.OverflowPolicy `mapstructure:"overflowPolicy" yaml:"overflowPolicy"`
AnalyzerSelectionMode engine.AnalyzerSelectionMode `mapstructure:"analyzerSelectionMode" yaml:"analyzerSelectionMode"`
} }
// RulesetConfig configures built-in rule helpers. // RulesetConfig configures built-in rule helpers.
+235
View File
@@ -0,0 +1,235 @@
package engine
import (
"bytes"
"strings"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// analyzerSelector optionally narrows the analyzer list attached to a new
// flow using a cheap signature sniff of the first payload bytes.
type analyzerSelector struct {
	mode  AnalyzerSelectionMode // "always" disables pruning; empty defaults to signature mode
	stats *statsCounters        // selection counters; may be nil (recordSelection tolerates it)
}
// newAnalyzerSelector builds a selector for the given mode, defaulting to
// signature-based selection when no mode is configured.
func newAnalyzerSelector(mode AnalyzerSelectionMode, stats *statsCounters) *analyzerSelector {
	sel := &analyzerSelector{mode: mode, stats: stats}
	if sel.mode == "" {
		sel.mode = AnalyzerSelectionModeSignature
	}
	return sel
}
// SelectTCP narrows the candidate TCP analyzers for a new flow based on a
// payload signature sniff. Analyzers whose names are not in
// knownTCPAnalyzers are always kept (custom analyzers are never pruned), and
// if pruning would remove everything, the full candidate set is returned so
// the heuristic can never cause a missed detection.
//
// Fix: recordSelection now reflects the *effective* selection. Previously it
// was invoked before the empty-set fallback, so a sniff that matched nothing
// we kept was still counted as a pruned selection even though all analyzers
// were ultimately run.
func (s *analyzerSelector) SelectTCP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
	if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
		return ans
	}
	allowed := tcpAllowedAnalyzers(payload)
	if len(allowed) == 0 {
		// No recognizable signature: stay conservative and run everything.
		return ans
	}
	out := make([]analyzer.Analyzer, 0, len(ans))
	for _, a := range ans {
		name := strings.ToLower(a.Name())
		if _, known := knownTCPAnalyzers[name]; !known || allowed[name] {
			out = append(out, a)
		}
	}
	if len(out) == 0 {
		// Conservative fallback keeps every analyzer; count it as unpruned.
		s.recordSelection(len(ans), len(ans))
		return ans
	}
	s.recordSelection(len(ans), len(out))
	return out
}
// SelectUDP narrows the candidate UDP analyzers for a new stream based on a
// payload signature sniff. Analyzers whose names are not in
// knownUDPAnalyzers are always kept, and if pruning would remove everything,
// the full candidate set is returned so the heuristic stays safe.
//
// Fix: recordSelection now reflects the *effective* selection. Previously it
// was invoked before the empty-set fallback, so a sniff that matched nothing
// we kept was still counted as a pruned selection even though all analyzers
// were ultimately run.
func (s *analyzerSelector) SelectUDP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
	if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
		return ans
	}
	allowed := udpAllowedAnalyzers(payload)
	if len(allowed) == 0 {
		// No recognizable signature: stay conservative and run everything.
		return ans
	}
	out := make([]analyzer.Analyzer, 0, len(ans))
	for _, a := range ans {
		name := strings.ToLower(a.Name())
		if _, known := knownUDPAnalyzers[name]; !known || allowed[name] {
			out = append(out, a)
		}
	}
	if len(out) == 0 {
		// Conservative fallback keeps every analyzer; count it as unpruned.
		s.recordSelection(len(ans), len(ans))
		return ans
	}
	s.recordSelection(len(ans), len(out))
	return out
}
// recordSelection bumps the selection counters. A selection counts as pruned
// only when strictly fewer analyzers were kept than were offered. Nil
// receivers/stats and empty candidate sets are silently ignored.
func (s *analyzerSelector) recordSelection(total, selected int) {
	if s == nil || s.stats == nil {
		return
	}
	if total <= 0 {
		return
	}
	s.stats.AnalyzerSelectionsTotal.Add(1)
	if selected >= total {
		return
	}
	s.stats.AnalyzerSelectionsPruned.Add(1)
}
// knownTCPAnalyzers / knownUDPAnalyzers list the built-in analyzer names
// (lowercase) that the signature selector understands. Analyzers whose names
// are NOT in these sets are never pruned, so custom analyzers are unaffected
// by signature-based selection.
var (
	knownTCPAnalyzers = map[string]struct{}{
		"fet": {},
		"http": {},
		"socks": {},
		"ssh": {},
		"tls": {},
		"trojan": {},
		"dns": {},
		"openvpn": {},
	}
	knownUDPAnalyzers = map[string]struct{}{
		"dns": {},
		"openvpn": {},
		"quic": {},
		"wireguard": {},
	}
)
// tcpAllowedAnalyzers sniffs the first client payload and returns the set of
// known analyzer names it plausibly belongs to. "fet" (fully-encrypted
// traffic) accompanies every positive match. A nil result means "no
// recognizable signature"; callers treat that as "do not prune".
func tcpAllowedAnalyzers(payload []byte) map[string]bool {
	allowed := map[string]bool{}
	add := func(names ...string) {
		for _, n := range names {
			allowed[n] = true
		}
	}
	if looksLikeTLS(payload) {
		add("tls", "trojan", "fet")
	}
	if looksLikeHTTP(payload) {
		add("http", "fet")
	}
	if looksLikeSSH(payload) {
		add("ssh", "fet")
	}
	if looksLikeSOCKS(payload) {
		add("socks", "fet")
	}
	if looksLikeDNSTCP(payload) {
		add("dns", "fet")
	}
	if len(allowed) == 0 {
		return nil
	}
	return allowed
}
// udpAllowedAnalyzers sniffs a UDP payload and returns the set of known
// analyzer names it plausibly belongs to. A nil result means "no
// recognizable signature"; callers treat that as "do not prune".
func udpAllowedAnalyzers(payload []byte) map[string]bool {
	checks := []struct {
		match func([]byte) bool
		name  string
	}{
		{looksLikeWireGuard, "wireguard"},
		{looksLikeOpenVPN, "openvpn"},
		{looksLikeQUIC, "quic"},
		{looksLikeDNSUDP, "dns"},
	}
	var allowed map[string]bool
	for _, c := range checks {
		if !c.match(payload) {
			continue
		}
		if allowed == nil {
			allowed = make(map[string]bool, 4)
		}
		allowed[c.name] = true
	}
	return allowed
}
// looksLikeTLS reports whether payload starts like a TLS record: content type
// handshake (0x16) or application data (0x17), legacy major version 0x03,
// and a minor version byte no greater than 0x09.
func looksLikeTLS(payload []byte) bool {
	if len(payload) < 3 {
		return false
	}
	switch payload[0] {
	case 0x16, 0x17:
		return payload[1] == 0x03 && payload[2] <= 0x09
	default:
		return false
	}
}
// looksLikeHTTP reports whether payload begins with the first three letters
// of a standard HTTP method, compared case-insensitively (GET, HEAD, POST,
// PUT, DELETE, CONNECT, OPTIONS, TRACE, PATCH).
func looksLikeHTTP(payload []byte) bool {
	if len(payload) < 3 {
		return false
	}
	head := strings.ToUpper(string(payload[:3]))
	prefixes := [...]string{"GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT"}
	for _, p := range prefixes {
		if head == p {
			return true
		}
	}
	return false
}
// looksLikeSSH reports whether payload begins with the SSH identification
// banner prefix "SSH-". The prefix check itself implies len >= 4.
func looksLikeSSH(payload []byte) bool {
	return bytes.HasPrefix(payload, []byte("SSH-"))
}
// looksLikeSOCKS reports whether payload could open a SOCKS session: at least
// two bytes, starting with version 4 (0x04) or version 5 (0x05).
func looksLikeSOCKS(payload []byte) bool {
	if len(payload) < 2 {
		return false
	}
	switch payload[0] {
	case 0x04, 0x05:
		return true
	}
	return false
}
// looksLikeDNSTCP reports whether payload resembles DNS over TCP: a 2-byte
// big-endian length prefix that fits within the buffer, followed by a DNS
// header whose question or answer count is non-zero. The 14-byte minimum
// covers the prefix plus a full 12-byte header.
func looksLikeDNSTCP(payload []byte) bool {
	if len(payload) < 14 {
		return false
	}
	prefixed := int(payload[0])<<8 | int(payload[1])
	if prefixed <= 0 || prefixed+2 > len(payload) {
		return false
	}
	// Header counts sit at message offsets 4-7, i.e. payload offsets 6-9.
	questions := int(payload[6])<<8 | int(payload[7])
	answers := int(payload[8])<<8 | int(payload[9])
	return questions > 0 || answers > 0
}
// looksLikeDNSUDP reports whether payload resembles a raw DNS message: at
// least a 12-byte header with any non-zero section count (QD, AN, NS or AR).
func looksLikeDNSUDP(payload []byte) bool {
	if len(payload) < 12 {
		return false
	}
	be16 := func(i int) int { return int(payload[i])<<8 | int(payload[i+1]) }
	// Section counts occupy header bytes 4-11.
	return be16(4)+be16(6)+be16(8)+be16(10) > 0
}
// looksLikeQUIC reports whether payload resembles a QUIC long-header packet:
// the high bit of the first byte set and a non-zero 32-bit version field
// (a zero version would be a version-negotiation packet).
func looksLikeQUIC(payload []byte) bool {
	if len(payload) < 6 {
		return false
	}
	const longHeaderBit = 0x80
	if payload[0]&longHeaderBit == 0 {
		return false
	}
	var version uint32
	for _, b := range payload[1:5] {
		version = version<<8 | uint32(b)
	}
	return version != 0
}
// looksLikeOpenVPN reports whether the first byte's high 5 bits decode to an
// OpenVPN opcode in the defined range 1..11.
func looksLikeOpenVPN(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}
	op := payload[0] >> 3
	return 1 <= op && op <= 11
}
// looksLikeWireGuard reports whether payload starts with a WireGuard message
// header: a type byte in 1..4 followed by three zero reserved bytes.
func looksLikeWireGuard(payload []byte) bool {
	if len(payload) < 4 {
		return false
	}
	msgType := payload[0]
	if msgType == 0 || msgType > 4 {
		return false
	}
	// All three reserved bytes must be zero.
	return payload[1]|payload[2]|payload[3] == 0
}
+56
View File
@@ -0,0 +1,56 @@
package engine
import (
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// namedAnalyzer is a minimal analyzer stub whose only meaningful property is
// its name; the selector tests use it to observe keep/prune decisions.
type namedAnalyzer struct{ name string }

// Name returns the configured analyzer name.
func (a namedAnalyzer) Name() string { return a.name }

// Limit returns 0, i.e. no byte limit, for the stub.
func (a namedAnalyzer) Limit() int { return 0 }
// TestSignatureSelectorTCPPrunesByPayloadNotPort verifies that a TLS-looking
// payload keeps only the TLS-compatible analyzers (tls, trojan, fet) and
// prunes unrelated ones, based purely on payload bytes (no port is involved).
func TestSignatureSelectorTCPPrunesByPayloadNotPort(t *testing.T) {
	selector := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
	candidates := []analyzer.Analyzer{
		namedAnalyzer{"http"},
		namedAnalyzer{"tls"},
		namedAnalyzer{"trojan"},
		namedAnalyzer{"ssh"},
		namedAnalyzer{"socks"},
		namedAnalyzer{"fet"},
	}
	// TLS record-like prefix, regardless of destination port.
	tlsPrefix := []byte{0x16, 0x03, 0x03, 0x00, 0x10}
	kept := map[string]bool{}
	for _, a := range selector.SelectTCP(candidates, tlsPrefix) {
		kept[a.Name()] = true
	}
	for _, want := range []string{"tls", "trojan", "fet"} {
		if !kept[want] {
			t.Fatalf("expected analyzer %q to be selected", want)
		}
	}
	for _, gone := range []string{"http", "ssh", "socks"} {
		if kept[gone] {
			t.Fatalf("expected analyzer %q to be pruned", gone)
		}
	}
}
// TestSignatureSelectorConservativeFallback verifies that an unrecognized
// payload leaves the candidate set untouched: when no signature matches,
// pruning must not happen.
func TestSignatureSelectorConservativeFallback(t *testing.T) {
	selector := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
	candidates := []analyzer.Analyzer{
		namedAnalyzer{"http"},
		namedAnalyzer{"tls"},
		namedAnalyzer{"custom"},
	}
	// Bytes that match no known protocol signature.
	blob := []byte{0xde, 0xad, 0xbe, 0xef}
	kept := selector.SelectTCP(candidates, blob)
	if len(kept) != len(candidates) {
		t.Fatalf("expected conservative fallback to keep all analyzers, got=%d want=%d", len(kept), len(candidates))
	}
}
+68 -25
View File
@@ -2,6 +2,7 @@ package engine
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -20,18 +21,34 @@ type engine struct {
logger Logger logger Logger
io io.PacketIO io io.PacketIO
workers []*worker workers []*worker
verdicts sync.Map // streamID(uint32) → verdictEntry stats *statsCounters
verdicts sync.Map // streamID(uint32) -> verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update verdictsGen atomic.Int64 // incremented on ruleset update
overflowCh chan *workerPacket overflowPolicy OverflowPolicy
overflowOnce sync.Once resultCh chan workerResult
} }
func NewEngine(config Config) (Engine, error) { func NewEngine(config Config) (Engine, error) {
workerCount := config.Workers workerCount := config.Workers
if workerCount <= 0 { if workerCount <= 0 {
workerCount = 1 workerCount = runtime.GOMAXPROCS(0)
if workerCount <= 0 {
workerCount = 1
}
} }
overflowPolicy := config.OverflowPolicy
if overflowPolicy == "" {
overflowPolicy = OverflowPolicyAccept
}
selectionMode := config.AnalyzerSelectionMode
if selectionMode == "" {
selectionMode = AnalyzerSelectionModeSignature
}
stats := &statsCounters{}
resultCh := make(chan workerResult, workerCount*256)
macResolver := newSourceMACResolver() macResolver := newSourceMACResolver()
var err error var err error
workers := make([]*worker, workerCount) workers := make([]*worker, workerCount)
@@ -45,16 +62,21 @@ func NewEngine(config Config) (Engine, error) {
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal, TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn, TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
UDPMaxStreams: config.WorkerUDPMaxStreams, UDPMaxStreams: config.WorkerUDPMaxStreams,
AnalyzerSelectionMode: selectionMode,
ResultChan: resultCh,
Stats: stats,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
e := &engine{ e := &engine{
logger: config.Logger, logger: config.Logger,
io: config.IO, io: config.IO,
workers: workers, workers: workers,
overflowCh: make(chan *workerPacket, 1024), stats: stats,
overflowPolicy: overflowPolicy,
resultCh: resultCh,
} }
return e, nil return e, nil
} }
@@ -74,13 +96,10 @@ func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx) ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() defer ioCancel()
e.overflowOnce.Do(func() {
go e.drainOverflow(ioCtx)
})
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
go e.drainResults(ioCtx)
errChan := make(chan error, 1) errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
@@ -121,24 +140,35 @@ func (e *engine) dispatch(p io.Packet) bool {
gen := e.verdictsGen.Load() gen := e.verdictsGen.Load()
index := streamID % uint32(len(e.workers)) index := streamID % uint32(len(e.workers))
wp := &workerPacket{ wp := &workerPacket{
StreamID: streamID, Packet: p,
Data: data, StreamID: streamID,
SetVerdict: func(v io.Verdict, b []byte) error { Data: data,
if v == io.VerdictAcceptStream || v == io.VerdictDropStream { Gen: gen,
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
}
return e.io.SetVerdict(p, v, b)
},
} }
if !e.workers[index].Feed(wp) { if !e.workers[index].Feed(wp) {
select { e.stats.OverflowEvents.Add(1)
case e.overflowCh <- wp: switch e.overflowPolicy {
case OverflowPolicyDrop:
e.stats.OverflowDrops.Add(1)
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
case OverflowPolicyBackpressure:
e.stats.OverflowBackpressureEvents.Add(1)
e.workers[index].FeedBlocking(wp)
default: default:
e.stats.OverflowAccepts.Add(1)
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
} }
} }
return true return true
} }
func (e *engine) applyWorkerResult(r workerResult) {
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
}
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}
func validPacket(data []byte) bool { func validPacket(data []byte) bool {
if len(data) == 0 { if len(data) == 0 {
return false return false
@@ -156,13 +186,26 @@ func validPacket(data []byte) bool {
return false return false
} }
func (e *engine) drainOverflow(ctx context.Context) { func (e *engine) drainResults(ctx context.Context) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case wp := <-e.overflowCh: case r := <-e.resultCh:
_ = wp.SetVerdict(io.VerdictAccept, nil) e.applyWorkerResult(r)
} }
} }
} }
func (e *engine) Stats() Stats {
return Stats{
OverflowEvents: e.stats.OverflowEvents.Load(),
OverflowAccepts: e.stats.OverflowAccepts.Load(),
OverflowDrops: e.stats.OverflowDrops.Load(),
OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(),
AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(),
UDPTupleLookups: e.stats.UDPTupleLookups.Load(),
UDPTupleHits: e.stats.UDPTupleHits.Load(),
}
}
+46
View File
@@ -2,6 +2,7 @@ package engine
import ( import (
"context" "context"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -13,6 +14,49 @@ type Engine interface {
UpdateRuleset(ruleset.Ruleset) error UpdateRuleset(ruleset.Ruleset) error
// Run runs the engine, until an error occurs or the context is cancelled. // Run runs the engine, until an error occurs or the context is cancelled.
Run(context.Context) error Run(context.Context) error
// Stats returns a consistent snapshot of runtime counters.
Stats() Stats
}
type OverflowPolicy string
const (
OverflowPolicyAccept OverflowPolicy = "accept"
OverflowPolicyDrop OverflowPolicy = "drop"
OverflowPolicyBackpressure OverflowPolicy = "backpressure"
)
type AnalyzerSelectionMode string
const (
AnalyzerSelectionModeAlways AnalyzerSelectionMode = "always"
AnalyzerSelectionModeSignature AnalyzerSelectionMode = "signature"
)
type statsCounters struct {
OverflowEvents atomic.Uint64
OverflowAccepts atomic.Uint64
OverflowDrops atomic.Uint64
OverflowBackpressureEvents atomic.Uint64
AnalyzerSelectionsTotal atomic.Uint64
AnalyzerSelectionsPruned atomic.Uint64
UDPTupleLookups atomic.Uint64
UDPTupleHits atomic.Uint64
}
type Stats struct {
OverflowEvents uint64
OverflowAccepts uint64
OverflowDrops uint64
OverflowBackpressureEvents uint64
AnalyzerSelectionsTotal uint64
AnalyzerSelectionsPruned uint64
UDPTupleLookups uint64
UDPTupleHits uint64
} }
// Config is the configuration for the engine. // Config is the configuration for the engine.
@@ -26,6 +70,8 @@ type Config struct {
WorkerTCPMaxBufferedPagesTotal int WorkerTCPMaxBufferedPagesTotal int
WorkerTCPMaxBufferedPagesPerConn int WorkerTCPMaxBufferedPagesPerConn int
WorkerUDPMaxStreams int WorkerUDPMaxStreams int
OverflowPolicy OverflowPolicy
AnalyzerSelectionMode AnalyzerSelectionMode
} }
// Logger is the combined logging interface for the engine, workers and analyzers. // Logger is the combined logging interface for the engine, workers and analyzers.
+3
View File
@@ -1,3 +1,6 @@
//go:build linux
// +build linux
package engine package engine
import ( import (
+17
View File
@@ -0,0 +1,17 @@
//go:build !linux
// +build !linux
package engine
import "net"
type sourceMACResolver struct{}
func newSourceMACResolver() *sourceMACResolver {
return &sourceMACResolver{}
}
func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
_ = ip
return nil
}
+2 -2
View File
@@ -142,7 +142,7 @@ func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("create node: %v", err) t.Fatalf("create node: %v", err)
} }
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
l3 := L3Info{ l3 := L3Info{
@@ -180,7 +180,7 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("create node: %v", err) t.Fatalf("create node: %v", err)
} }
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
l3 := L3Info{ l3 := L3Info{
+8 -4
View File
@@ -163,15 +163,17 @@ type tcpFlowManager struct {
rulesetSource func() (ruleset.Ruleset, uint64) rulesetSource func() (ruleset.Ruleset, uint64)
workerID int workerID int
macResolver *sourceMACResolver macResolver *sourceMACResolver
selector *analyzerSelector
} }
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager { func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
return &tcpFlowManager{ return &tcpFlowManager{
flows: make(map[uint32]*tcpFlow), flows: make(map[uint32]*tcpFlow),
sfNode: node, sfNode: node,
logger: logger, logger: logger,
workerID: workerID, workerID: workerID,
macResolver: macResolver, macResolver: macResolver,
selector: selector,
} }
} }
@@ -179,7 +181,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
m.mu.Lock() m.mu.Lock()
flow, ok := m.flows[streamID] flow, ok := m.flows[streamID]
if !ok { if !ok {
flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC) flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
m.flows[streamID] = flow m.flows[streamID] = flow
} }
m.mu.Unlock() m.mu.Unlock()
@@ -195,7 +197,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
return verdict return verdict
} }
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
id := m.sfNode.Generate() id := m.sfNode.Generate()
ipSrc := net.IP(l3.SrcIP[:]) ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:]) ipDst := net.IP(l3.DstIP[:])
@@ -217,7 +219,9 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, src
rs, version := m.rulesetSource() rs, version := m.rulesetSource()
var ans []analyzer.TCPAnalyzer var ans []analyzer.TCPAnalyzer
if rs != nil { if rs != nil {
ans = analyzersToTCPAnalyzers(rs.Analyzers(info)) baseAns := rs.Analyzers(info)
baseAns = m.selector.SelectTCP(baseAns, payload)
ans = analyzersToTCPAnalyzers(baseAns)
} }
entries := make([]*tcpFlowEntry, 0, len(ans)) entries := make([]*tcpFlowEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
+109 -21
View File
@@ -1,6 +1,7 @@
package engine package engine
import ( import (
"bytes"
"errors" "errors"
"net" "net"
"sync" "sync"
@@ -40,6 +41,8 @@ type udpStreamFactory struct {
WorkerID int WorkerID int
Logger Logger Logger Logger
Node *snowflake.Node Node *snowflake.Node
Selector *analyzerSelector
Stats *statsCounters
RulesetMutex sync.RWMutex RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
@@ -64,7 +67,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
rs, version := f.currentRuleset() rs, version := f.currentRuleset()
var ans []analyzer.UDPAnalyzer var ans []analyzer.UDPAnalyzer
if rs != nil { if rs != nil {
ans = analyzersToUDPAnalyzers(rs.Analyzers(info)) baseAns := rs.Analyzers(info)
if f.Selector != nil {
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
}
ans = analyzersToUDPAnalyzers(baseAns)
} }
// Create entries for each analyzer // Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans)) entries := make([]*udpStreamEntry, 0, len(ans))
@@ -110,8 +117,11 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
} }
type udpStreamManager struct { type udpStreamManager struct {
factory *udpStreamFactory factory *udpStreamFactory
streams *lru.Cache[uint32, *udpStreamValue] streams *lru.Cache[uint32, *udpStreamValue]
tupleIndex map[udpTupleKey]uint32
streamTuples map[uint32]udpTupleKey
stats *statsCounters
} }
type udpStreamValue struct { type udpStreamValue struct {
@@ -120,36 +130,71 @@ type udpStreamValue struct {
UDPFlow gopacket.Flow UDPFlow gopacket.Flow
} }
type udpTupleKey struct {
AIP [16]byte
BIP [16]byte
ALen uint8
BLen uint8
APort uint16
BPort uint16
}
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) { func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse() rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
return fwd || rev, rev return fwd || rev, rev
} }
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) { func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
ss, err := lru.New[uint32, *udpStreamValue](maxStreams) m := &udpStreamManager{
factory: factory,
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
stats: stats,
}
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
m.removeTupleMappingLocked(k)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &udpStreamManager{ m.streams = ss
factory: factory, return m, nil
streams: ss,
}, nil
} }
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) { func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
rev := false rev := false
value, ok := m.streams.Get(streamID) value, ok := m.streams.Get(streamID)
tuple := canonicalUDPTupleKey(ipFlow, udp)
if !ok { if !ok {
// Fallback: conntrack IDs can change during early flow lifetime on some systems. if m.stats != nil {
// Try to find an existing stream by 5-tuple before creating a new stream. m.stats.UDPTupleLookups.Add(1)
matchedKey, matchedValue, matchedRev, found := m.findByFlow(ipFlow, udp.TransportFlow()) }
// Conntrack IDs can change during early flow lifetime on some systems.
// Rebind by canonical 5-tuple in O(1).
matchedKey, found := m.tupleIndex[tuple]
var matchedValue *udpStreamValue
var matchedRev bool
if found { if found {
if m.stats != nil {
m.stats.UDPTupleHits.Add(1)
}
var hasValue bool
matchedValue, hasValue = m.streams.Get(matchedKey)
if !hasValue || matchedValue == nil {
delete(m.tupleIndex, tuple)
delete(m.streamTuples, matchedKey)
found = false
}
}
if found {
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
value = matchedValue value = matchedValue
rev = matchedRev rev = matchedRev
if matchedKey != streamID { if matchedKey != streamID {
m.streams.Remove(matchedKey) m.streams.Remove(matchedKey)
m.streams.Add(streamID, matchedValue) m.streams.Add(streamID, matchedValue)
m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// New stream // New stream
@@ -159,6 +204,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
UDPFlow: udp.TransportFlow(), UDPFlow: udp.TransportFlow(),
} }
m.streams.Add(streamID, value) m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// Stream ID exists, but is it really the same stream? // Stream ID exists, but is it really the same stream?
@@ -172,6 +218,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
UDPFlow: udp.TransportFlow(), UDPFlow: udp.TransportFlow(),
} }
m.streams.Add(streamID, value) m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
} }
} }
if value.Stream.Accept(udp, rev, uc) { if value.Stream.Accept(udp, rev, uc) {
@@ -179,17 +226,58 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
} }
} }
func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32, value *udpStreamValue, rev bool, found bool) { func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
for _, k := range m.streams.Keys() { m.removeTupleMappingLocked(streamID)
v, ok := m.streams.Peek(k) m.tupleIndex[key] = streamID
if !ok || v == nil { m.streamTuples[streamID] = key
continue }
}
if ok2, rev2 := v.Match(ipFlow, udpFlow); ok2 { func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
return k, v, rev2, true if key, ok := m.streamTuples[streamID]; ok {
delete(m.streamTuples, streamID)
current, exists := m.tupleIndex[key]
if exists && current == streamID {
delete(m.tupleIndex, key)
} }
} }
return 0, nil, false, false }
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
srcIP := ipFlow.Src().Raw()
dstIP := ipFlow.Dst().Raw()
srcPort := uint16(udp.SrcPort)
dstPort := uint16(udp.DstPort)
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
srcIP, dstIP = dstIP, srcIP
srcPort, dstPort = dstPort, srcPort
}
var key udpTupleKey
key.ALen = uint8(copy(key.AIP[:], srcIP))
key.BLen = uint8(copy(key.BIP[:], dstIP))
key.APort = srcPort
key.BPort = dstPort
return key
}
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
if len(aIP) != len(bIP) {
if len(aIP) < len(bIP) {
return -1
}
return 1
}
if c := bytes.Compare(aIP, bIP); c != 0 {
return c
}
if aPort < bPort {
return -1
}
if aPort > bPort {
return 1
}
return 0
} }
type udpStream struct { type udpStream struct {
+122
View File
@@ -0,0 +1,122 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// legacyUDPStreamValue mirrors the pre-index stream value layout (IP flow +
// UDP flow only); the legacy benchmark uses it to reproduce the old linear
// 5-tuple fallback scan.
type legacyUDPStreamValue struct {
	IPFlow gopacket.Flow
	UDPFlow gopacket.Flow
}
// emptyRuleset is a ruleset stub that supplies no analyzers and never reaches
// a final verdict, keeping benchmark iterations free of analyzer work.
type emptyRuleset struct{}

// Analyzers returns nil: streams are created with no analyzer entries.
func (emptyRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil }

// Match always returns ActionMaybe so streams stay undecided.
func (emptyRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
	return ruleset.MatchResult{Action: ruleset.ActionMaybe}
}
// benchmarkUDPManager drives MatchWithContext over a fixed set of flows.
// With churn=false every flow keeps a stable stream ID; with churn=true each
// full pass over the flow set presents brand-new stream IDs, forcing the
// manager to rebind existing streams by 5-tuple.
func benchmarkUDPManager(b *testing.B, churn bool) {
	sfNode, err := snowflake.NewNode(0)
	if err != nil {
		b.Fatalf("create node: %v", err)
	}
	factory := &udpStreamFactory{WorkerID: 0, Logger: noopTestLogger{}, Node: sfNode, Ruleset: emptyRuleset{}}
	manager, err := newUDPStreamManager(factory, 200000, &statsCounters{})
	if err != nil {
		b.Fatalf("new manager: %v", err)
	}
	const flowCount = 20000
	ipFlows := make([]gopacket.Flow, flowCount)
	packets := make([]*layers.UDP, flowCount)
	for i := range ipFlows {
		hi, lo := byte(i>>8), byte(i)
		ipFlows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, hi, 0, lo).To4(), net.IPv4(172, 16, hi, lo).To4())
		packets[i] = &layers.UDP{
			SrcPort:   layers.UDPPort(1024 + i%20000),
			DstPort:   layers.UDPPort(20000 + (i*7)%20000),
			BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
		}
	}
	uc := &udpContext{Verdict: udpVerdictAccept}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		idx := i % flowCount
		streamID := uint32(idx + 1)
		if churn {
			// A fresh conntrack-style ID on every pass over the flow set.
			streamID = uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
		}
		uc.Verdict = udpVerdictAccept
		uc.Packet = nil
		manager.MatchWithContext(streamID, ipFlows[idx], packets[idx], uc)
	}
}
// BenchmarkUDPManagerMatchStableStreamID measures the common case where the
// conntrack stream ID never changes for a flow.
func BenchmarkUDPManagerMatchStableStreamID(b *testing.B) {
	benchmarkUDPManager(b, false)
}

// BenchmarkUDPManagerMatchStreamIDChurn measures rebinding cost when every
// pass presents a new stream ID for the same 5-tuple.
func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
	benchmarkUDPManager(b, true)
}
// BenchmarkLegacyUDPFallbackScanChurn reproduces the pre-index behavior as a
// baseline: when a packet arrives under an unknown stream ID, the old code
// linearly scanned every existing stream and compared 5-tuples (forward and
// reverse) before rebinding. Contrast with the O(1) tuple-index path measured
// by BenchmarkUDPManagerMatchStreamIDChurn.
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
	const flowCount = 5000
	flows := make([]gopacket.Flow, flowCount)
	udps := make([]*layers.UDP, flowCount)
	for i := 0; i < flowCount; i++ {
		a := byte(i >> 8)
		c := byte(i)
		flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
		udps[i] = &layers.UDP{
			SrcPort: layers.UDPPort(1024 + i%20000),
			DstPort: layers.UDPPort(20000 + (i*7)%20000),
			BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
		}
	}
	// Seed one stream per flow under its original ID.
	streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
	keys := make([]uint32, 0, flowCount)
	for i := 0; i < flowCount; i++ {
		streamID := uint32(i + 1)
		streams[streamID] = &legacyUDPStreamValue{
			IPFlow: flows[i],
			UDPFlow: udps[i].TransportFlow(),
		}
		keys = append(keys, streamID)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		idx := i % flowCount
		// Churn: each full pass over the flow set uses brand-new stream IDs.
		streamID := uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
		if _, ok := streams[streamID]; ok {
			continue
		}
		ipFlow := flows[idx]
		udpFlow := udps[idx].TransportFlow()
		// Legacy fallback: O(n) scan over all known streams, then rebind.
		for _, k := range keys {
			v, ok := streams[k]
			if !ok || v == nil {
				continue
			}
			if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
				(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
				delete(streams, k)
				streams[streamID] = v
				break
			}
		}
	}
}
+71
View File
@@ -0,0 +1,71 @@
package engine
import (
"net"
"sync/atomic"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// countingRuleset is a test ruleset that hands the engine a fixed analyzer
// list and never reaches a final verdict.
type countingRuleset struct {
	ans []analyzer.Analyzer
}

func (r countingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return r.ans }

// Match always returns ActionMaybe so streams remain under analysis.
func (r countingRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
	return ruleset.MatchResult{Action: ruleset.ActionMaybe}
}
// countingUDPAnalyzer counts how many UDP streams the engine asks it to
// create; the counter is shared via pointer so value copies still record.
type countingUDPAnalyzer struct{ newCalls *atomic.Uint64 }

func (a countingUDPAnalyzer) Name() string { return "countudp" }

// Limit returns 0 — assumed to mean "no byte limit"; semantics come from the
// analyzer.Analyzer interface.
func (a countingUDPAnalyzer) Limit() int { return 0 }

// NewUDP records the stream creation and hands back an inert stream.
func (a countingUDPAnalyzer) NewUDP(analyzer.UDPInfo, analyzer.Logger) analyzer.UDPStream {
	a.newCalls.Add(1)
	return countingUDPStream{}
}
// countingUDPStream is an inert analyzer.UDPStream: it consumes nothing and
// never produces property updates.
type countingUDPStream struct{}

func (countingUDPStream) Feed(bool, []byte) (*analyzer.PropUpdate, bool) { return nil, false }

func (countingUDPStream) Close(bool) *analyzer.PropUpdate { return nil }
// TestUDPStreamManagerRebindsByTupleInO1Path verifies that when the same UDP
// 5-tuple arrives under a different stream ID, the manager rebinds the
// existing stream instead of creating a new one (analyzer NewUDP stays at 1).
func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
	sfNode, err := snowflake.NewNode(0)
	if err != nil {
		t.Fatalf("create node: %v", err)
	}
	var creations atomic.Uint64
	mgr, err := newUDPStreamManager(&udpStreamFactory{
		WorkerID: 0,
		Logger:   noopTestLogger{},
		Node:     sfNode,
		Ruleset: countingRuleset{
			ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &creations}},
		},
	}, 64, &statsCounters{})
	if err != nil {
		t.Fatalf("new manager: %v", err)
	}

	flow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
	pkt := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}

	// First packet on stream ID 100 must create exactly one stream.
	mgr.MatchWithContext(100, flow, pkt, &udpContext{Verdict: udpVerdictAccept})
	if got := creations.Load(); got != 1 {
		t.Fatalf("new stream calls=%d want=1", got)
	}

	// Same tuple under stream ID 200: must be rebound, not recreated.
	mgr.MatchWithContext(200, flow, pkt, &udpContext{Verdict: udpVerdictAccept})
	if got := creations.Load(); got != 1 {
		t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
	}
}
+38 -14
View File
@@ -12,24 +12,32 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
) )
var _ Engine = (*engine)(nil)
type workerPacket struct { type workerPacket struct {
StreamID uint32 Packet io.Packet
Data []byte StreamID uint32
SrcMAC net.HardwareAddr Data []byte
DstMAC net.HardwareAddr SrcMAC net.HardwareAddr
SetVerdict func(io.Verdict, []byte) error DstMAC net.HardwareAddr
Gen int64
}
type workerResult struct {
Packet io.Packet
StreamID uint32
Verdict io.Verdict
ModifiedPacket []byte
Gen int64
} }
type worker struct { type worker struct {
id int id int
packetChan chan *workerPacket packetChan chan *workerPacket
resultChan chan workerResult
logger Logger logger Logger
macResolver *sourceMACResolver macResolver *sourceMACResolver
tcpFlowMgr *tcpFlowManager tcpFlowMgr *tcpFlowManager
udpSM *udpStreamManager udpSM *udpStreamManager
modSerializeBuffer gopacket.SerializeBuffer modSerializeBuffer gopacket.SerializeBuffer
} }
@@ -43,6 +51,9 @@ type workerConfig struct {
TCPMaxBufferedPagesTotal int // unused, kept for config compat TCPMaxBufferedPagesTotal int // unused, kept for config compat
TCPMaxBufferedPagesPerConn int // unused, kept for config compat TCPMaxBufferedPagesPerConn int // unused, kept for config compat
UDPMaxStreams int UDPMaxStreams int
AnalyzerSelectionMode AnalyzerSelectionMode
ResultChan chan workerResult
Stats *statsCounters
} }
func (c *workerConfig) fillDefaults() { func (c *workerConfig) fillDefaults() {
@@ -61,7 +72,8 @@ func newWorker(config workerConfig) (*worker, error) {
return nil, err return nil, err
} }
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode) selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
if config.Ruleset != nil { if config.Ruleset != nil {
tcpMgr.updateRuleset(config.Ruleset, 0) tcpMgr.updateRuleset(config.Ruleset, 0)
} }
@@ -71,8 +83,10 @@ func newWorker(config workerConfig) (*worker, error) {
Logger: config.Logger, Logger: config.Logger,
Node: sfNode, Node: sfNode,
Ruleset: config.Ruleset, Ruleset: config.Ruleset,
Selector: selector,
Stats: config.Stats,
} }
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams) udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -80,6 +94,7 @@ func newWorker(config workerConfig) (*worker, error) {
return &worker{ return &worker{
id: config.ID, id: config.ID,
packetChan: make(chan *workerPacket, config.ChanSize), packetChan: make(chan *workerPacket, config.ChanSize),
resultChan: config.ResultChan,
logger: config.Logger, logger: config.Logger,
macResolver: config.MACResolver, macResolver: config.MACResolver,
tcpFlowMgr: tcpMgr, tcpFlowMgr: tcpMgr,
@@ -97,6 +112,10 @@ func (w *worker) Feed(p *workerPacket) bool {
} }
} }
func (w *worker) FeedBlocking(p *workerPacket) {
w.packetChan <- p
}
func (w *worker) Run(ctx context.Context) { func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id) w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id) defer w.logger.WorkerStop(w.id)
@@ -109,7 +128,13 @@ func (w *worker) Run(ctx context.Context) {
return return
} }
v, b := w.handle(wp) v, b := w.handle(wp)
_ = wp.SetVerdict(v, b) w.resultChan <- workerResult{
Packet: wp.Packet,
StreamID: wp.StreamID,
Verdict: v,
ModifiedPacket: b,
Gen: wp.Gen,
}
} }
} }
} }
@@ -185,8 +210,7 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
SrcMAC: srcMAC, SrcMAC: srcMAC,
DstMAC: dstMAC, DstMAC: dstMAC,
} }
// Temporarily set payload on a UDP layer so existing UDP handling works // Temporarily set payload on a UDP layer so existing UDP handling works.
// We pass the payload through the context
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
BaseLayer: layers.BaseLayer{Payload: payload}, BaseLayer: layers.BaseLayer{Payload: payload},
SrcPort: layers.UDPPort(udp.SrcPort), SrcPort: layers.UDPPort(udp.SrcPort),
+3
View File
@@ -1,3 +1,6 @@
//go:build linux
// +build linux
package io package io
import ( import (
+43
View File
@@ -0,0 +1,43 @@
//go:build !linux
// +build !linux
package io
import (
"context"
"errors"
"net"
)
// errNFQueueUnsupported is returned by every entry point of this non-Linux
// stub: NFQUEUE is a Linux netfilter facility.
var errNFQueueUnsupported = errors.New("nfqueue packet io is only supported on linux")

// NFQueuePacketIOConfig mirrors the Linux implementation's configuration so
// callers compile unchanged on all platforms; the fields are ignored here.
type NFQueuePacketIOConfig struct {
	QueueSize    uint32
	ReadBuffer   int
	WriteBuffer  int
	Local        bool
	RST          bool
	NumQueues    int
	MaxPacketLen uint32
}
// NewNFQueuePacketIO always fails on non-Linux platforms: NFQUEUE packet IO
// requires the Linux netfilter queue subsystem. The config is accepted only
// for API compatibility with the Linux build.
//
// The parameter is intentionally unnamed — Go permits unused parameters, so
// the previous `_ = config` blank assignment was a no-op and is removed.
func NewNFQueuePacketIO(NFQueuePacketIOConfig) (PacketIO, error) {
	return nil, errNFQueueUnsupported
}
// unsupportedPacketIO is a stand-in whose every operation reports that
// NFQUEUE is unavailable on this platform.
type unsupportedPacketIO struct{}

func (*unsupportedPacketIO) Register(context.Context, PacketCallback) error {
	return errNFQueueUnsupported
}

func (*unsupportedPacketIO) SetVerdict(Packet, Verdict, []byte) error {
	return errNFQueueUnsupported
}

func (*unsupportedPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) {
	return nil, errNFQueueUnsupported
}

// Close is a no-op: the stub holds no resources.
func (*unsupportedPacketIO) Close() error { return nil }
+233 -54
View File
@@ -1,52 +1,45 @@
package geo package geo
import ( import (
"container/list"
"net" "net"
"sort"
"strings" "strings"
"sync" "sync"
) )
const (
geoSiteResultCacheSize = 1 << 16
geoSiteSetResultCacheSize = 1 << 16
)
type GeoMatcher struct { type GeoMatcher struct {
geoLoader GeoLoader geoLoader GeoLoader
geoSiteMatcher map[string]hostMatcher geoSiteMatcher map[string]hostMatcher
siteMatcherLock sync.Mutex siteMatcherLock sync.RWMutex
geoSiteSets map[string][]hostMatcher
siteSetLock sync.RWMutex
geoIpMatcher map[string]hostMatcher geoIpMatcher map[string]hostMatcher
ipMatcherLock sync.Mutex ipMatcherLock sync.RWMutex
geoSiteResult *boolLRUCache
geoSiteSetCache *boolLRUCache
} }
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher { func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
return &GeoMatcher{ return &GeoMatcher{
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename), geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher), geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher), geoSiteSets: make(map[string][]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher),
geoSiteResult: newBoolLRUCache(geoSiteResultCacheSize),
geoSiteSetCache: newBoolLRUCache(geoSiteSetResultCacheSize),
} }
} }
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
g.ipMatcherLock.Lock() matcher, ok := g.getOrCreateGeoIPMatcher(condition)
defer g.ipMatcherLock.Unlock() if !ok || matcher == nil {
return false
matcher, ok := g.geoIpMatcher[condition]
if !ok {
// GeoIP matcher
condition = strings.ToLower(condition)
country := condition
if len(country) == 0 {
return false
}
gMap, err := g.geoLoader.LoadGeoIP()
if err != nil {
return false
}
list, ok := gMap[country]
if !ok || list == nil {
return false
}
matcher, err = newGeoIPMatcher(list)
if err != nil {
return false
}
g.geoIpMatcher[condition] = matcher
} }
parseIp := net.ParseIP(ip) parseIp := net.ParseIP(ip)
if parseIp == nil { if parseIp == nil {
@@ -64,32 +57,69 @@ func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
} }
func (g *GeoMatcher) MatchGeoSite(site, condition string) bool { func (g *GeoMatcher) MatchGeoSite(site, condition string) bool {
g.siteMatcherLock.Lock() conditionKey := strings.TrimSpace(strings.ToLower(condition))
defer g.siteMatcherLock.Unlock() if conditionKey == "" {
return false
matcher, ok := g.geoSiteMatcher[condition]
if !ok {
// MatchGeoSite matcher
condition = strings.ToLower(condition)
name, attrs := parseGeoSiteName(condition)
if len(name) == 0 {
return false
}
gMap, err := g.geoLoader.LoadGeoSite()
if err != nil {
return false
}
list, ok := gMap[name]
if !ok || list == nil {
return false
}
matcher, err = newGeositeMatcher(list, attrs)
if err != nil {
return false
}
g.geoSiteMatcher[condition] = matcher
} }
return matcher.Match(HostInfo{Name: site}) cacheKey := site + "\x1f" + conditionKey
if v, ok := g.geoSiteResult.Get(cacheKey); ok {
return v
}
matcher, ok := g.getOrCreateGeoSiteMatcher(condition)
if !ok || matcher == nil {
return false
}
result := matcher.Match(HostInfo{Name: site})
g.geoSiteResult.Set(cacheKey, result)
return result
}
func (g *GeoMatcher) MatchGeoSiteSet(site string, set *SiteConditionSet) bool {
if set == nil {
return false
}
conditions := normalizeGeoSiteSetConditions(set.Conditions)
if len(conditions) == 0 {
return false
}
key := strings.Join(conditions, "\x1f")
cacheKey := site + "\x1e" + key
if v, ok := g.geoSiteSetCache.Get(cacheKey); ok {
return v
}
g.siteSetLock.RLock()
matchers, ok := g.geoSiteSets[key]
g.siteSetLock.RUnlock()
if !ok {
compiled := make([]hostMatcher, 0, len(conditions))
for _, condition := range conditions {
m, ok := g.getOrCreateGeoSiteMatcher(condition)
if ok && m != nil {
compiled = append(compiled, m)
}
}
g.siteSetLock.Lock()
if existing, exists := g.geoSiteSets[key]; exists {
matchers = existing
} else {
g.geoSiteSets[key] = compiled
matchers = compiled
}
g.siteSetLock.Unlock()
}
if len(matchers) == 0 {
return false
}
host := HostInfo{Name: site}
for _, matcher := range matchers {
if matcher.Match(host) {
g.geoSiteSetCache.Set(cacheKey, true)
return true
}
}
g.geoSiteSetCache.Set(cacheKey, false)
return false
} }
func (g *GeoMatcher) LoadGeoSite() error { func (g *GeoMatcher) LoadGeoSite() error {
@@ -111,3 +141,152 @@ func parseGeoSiteName(s string) (string, []string) {
} }
return base, attrs return base, attrs
} }
// getOrCreateGeoSiteMatcher returns the compiled matcher for a geosite
// condition (a list name with optional attributes, as split by
// parseGeoSiteName), building and caching it on first use. The boolean is
// false when the condition is empty, the geosite data cannot be loaded, the
// named list is missing, or compilation fails.
func (g *GeoMatcher) getOrCreateGeoSiteMatcher(condition string) (hostMatcher, bool) {
	condition = strings.TrimSpace(strings.ToLower(condition))
	if condition == "" {
		return nil, false
	}
	// Fast path: most calls hit an already-compiled matcher under a read lock.
	g.siteMatcherLock.RLock()
	matcher, ok := g.geoSiteMatcher[condition]
	g.siteMatcherLock.RUnlock()
	if ok {
		return matcher, true
	}
	// Slow path: compile outside any lock so loading/compiling doesn't block readers.
	name, attrs := parseGeoSiteName(condition)
	if len(name) == 0 {
		return nil, false
	}
	gMap, err := g.geoLoader.LoadGeoSite()
	if err != nil {
		return nil, false
	}
	list, ok := gMap[name]
	if !ok || list == nil {
		return nil, false
	}
	matcher, err = newGeositeMatcher(list, attrs)
	if err != nil {
		return nil, false
	}
	// Re-check under the write lock: a concurrent caller may have compiled the
	// same condition; keep the first so all callers share one matcher.
	g.siteMatcherLock.Lock()
	if existing, exists := g.geoSiteMatcher[condition]; exists {
		matcher = existing
	} else {
		g.geoSiteMatcher[condition] = matcher
	}
	g.siteMatcherLock.Unlock()
	return matcher, true
}
// getOrCreateGeoIPMatcher returns the compiled matcher for a GeoIP condition
// (a country/list key in the loaded GeoIP map), building and caching it on
// first use. The boolean is false when the condition is empty, the GeoIP
// data cannot be loaded, the key is missing, or compilation fails.
func (g *GeoMatcher) getOrCreateGeoIPMatcher(condition string) (hostMatcher, bool) {
	condition = strings.TrimSpace(strings.ToLower(condition))
	if condition == "" {
		return nil, false
	}
	// Fast path: read-locked lookup of an already-compiled matcher.
	g.ipMatcherLock.RLock()
	matcher, ok := g.geoIpMatcher[condition]
	g.ipMatcherLock.RUnlock()
	if ok {
		return matcher, true
	}
	// Slow path: load and compile outside the lock.
	gMap, err := g.geoLoader.LoadGeoIP()
	if err != nil {
		return nil, false
	}
	list, ok := gMap[condition]
	if !ok || list == nil {
		return nil, false
	}
	matcher, err = newGeoIPMatcher(list)
	if err != nil {
		return nil, false
	}
	// Re-check under the write lock so concurrent builders converge on one matcher.
	g.ipMatcherLock.Lock()
	if existing, exists := g.geoIpMatcher[condition]; exists {
		matcher = existing
	} else {
		g.geoIpMatcher[condition] = matcher
	}
	g.ipMatcherLock.Unlock()
	return matcher, true
}
// normalizeGeoSiteSetConditions canonicalizes a list of geosite conditions:
// each entry is lowercased and trimmed, blanks and duplicates are dropped,
// and the survivors are returned sorted so equal sets always yield the same
// key. A nil or empty input yields nil.
func normalizeGeoSiteSetConditions(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	// Deduplicate via a set, then emit the members in sorted order.
	set := make(map[string]struct{}, len(in))
	for _, raw := range in {
		if cond := strings.TrimSpace(strings.ToLower(raw)); cond != "" {
			set[cond] = struct{}{}
		}
	}
	out := make([]string, 0, len(set))
	for cond := range set {
		out = append(out, cond)
	}
	sort.Strings(out)
	return out
}
// boolLRUCache is a fixed-capacity, mutex-guarded LRU cache from string keys
// to boolean results. Get promotes an entry to most-recently-used; Set
// inserts or updates and evicts the least-recently-used entry once capacity
// is exceeded.
type boolLRUCache struct {
	mu    sync.Mutex
	cap   int
	ll    *list.List
	items map[string]*list.Element
}

// boolCacheEntry is the list payload; it carries the key so eviction can
// remove the map entry for the list's back element.
type boolCacheEntry struct {
	key   string
	value bool
}

// newBoolLRUCache returns a cache holding at most capacity entries; a
// non-positive capacity is clamped to one.
func newBoolLRUCache(capacity int) *boolLRUCache {
	if capacity < 1 {
		capacity = 1
	}
	return &boolLRUCache{
		cap:   capacity,
		ll:    list.New(),
		items: make(map[string]*list.Element, capacity),
	}
}

// Get reports the cached value for key and whether it was present, marking
// the entry as most recently used on a hit.
func (c *boolLRUCache) Get(key string) (bool, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	ele, ok := c.items[key]
	if !ok {
		return false, false
	}
	c.ll.MoveToFront(ele)
	return ele.Value.(boolCacheEntry).value, true
}

// Set stores value under key. An existing entry is updated and promoted; a
// new entry that pushes the cache past capacity evicts the LRU entry.
func (c *boolLRUCache) Set(key string, value bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if ele, ok := c.items[key]; ok {
		ele.Value = boolCacheEntry{key: key, value: value}
		c.ll.MoveToFront(ele)
		return
	}
	c.items[key] = c.ll.PushFront(boolCacheEntry{key: key, value: value})
	if c.ll.Len() <= c.cap {
		return
	}
	if oldest := c.ll.Back(); oldest != nil {
		delete(c.items, oldest.Value.(boolCacheEntry).key)
		c.ll.Remove(oldest)
	}
}
+79 -1
View File
@@ -1,13 +1,14 @@
package geo package geo
import ( import (
"sync/atomic"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
) )
type fakeGeoLoader struct { type fakeGeoLoader struct {
geoip map[string]*v2geo.GeoIP geoip map[string]*v2geo.GeoIP
geosite map[string]*v2geo.GeoSite geosite map[string]*v2geo.GeoSite
} }
@@ -110,6 +111,83 @@ func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
} }
} }
// TestGeoMatcher_MatchGeoSiteSet checks set matching end to end: conditions
// are trimmed and deduplicated, a host matching any list satisfies the set,
// and unrelated hosts do not match.
func TestGeoMatcher_MatchGeoSiteSet(t *testing.T) {
	loader := &fakeGeoLoader{
		geosite: map[string]*v2geo.GeoSite{
			"openai": {
				Domain: []*v2geo.Domain{
					{Type: v2geo.Domain_Plain, Value: "openai"},
				},
			},
			"google": {
				Domain: []*v2geo.Domain{
					{Type: v2geo.Domain_RootDomain, Value: "google.com"},
				},
			},
		},
	}
	g := NewGeoMatcher("", "")
	g.geoLoader = loader
	// " google " and "OPENAI" exercise trimming/lowercasing; the duplicate
	// "OPENAI" exercises deduplication.
	set := &SiteConditionSet{Conditions: []string{" google ", "openai", "OPENAI"}}
	if !g.MatchGeoSiteSet("api.openai.com", set) {
		t.Error("MatchGeoSiteSet should match openai")
	}
	if !g.MatchGeoSiteSet("mail.google.com", set) {
		t.Error("MatchGeoSiteSet should match google")
	}
	if g.MatchGeoSiteSet("example.com", set) {
		t.Error("MatchGeoSiteSet should not match unrelated host")
	}
}
// countingMatcher is a hostMatcher stub that counts Match invocations and
// returns a fixed verdict; used to prove the result caches short-circuit
// repeated matching.
type countingMatcher struct {
	calls *atomic.Uint64
	match bool
}

// Match records the invocation and returns the configured verdict,
// regardless of the host.
func (m countingMatcher) Match(HostInfo) bool {
	m.calls.Add(1)
	return m.match
}
// TestGeoMatcher_MatchGeoSite_UsesResultCache preseeds a compiled matcher and
// checks that a repeated (site, condition) query is answered from the result
// cache: the matcher runs exactly once, the second call is a cache hit.
func TestGeoMatcher_MatchGeoSite_UsesResultCache(t *testing.T) {
	g := NewGeoMatcher("", "")
	var calls atomic.Uint64
	g.geoSiteMatcher["openai"] = countingMatcher{calls: &calls, match: true}
	if !g.MatchGeoSite("api.openai.com", "openai") {
		t.Fatal("expected match")
	}
	if !g.MatchGeoSite("api.openai.com", "openai") {
		t.Fatal("expected cached match")
	}
	if got := calls.Load(); got != 1 {
		t.Fatalf("matcher calls=%d want=1", got)
	}
}
// TestGeoMatcher_MatchGeoSiteSet_UsesResultCache preseeds a compiled set
// (keyed by the sorted, \x1f-joined conditions) and checks the second query
// is served from the set result cache: two matcher calls total — one per
// matcher during the first, uncached evaluation (false then true) and none
// on the repeat.
func TestGeoMatcher_MatchGeoSiteSet_UsesResultCache(t *testing.T) {
	g := NewGeoMatcher("", "")
	var calls atomic.Uint64
	g.geoSiteSets["openai\x1fyoutube"] = []hostMatcher{
		countingMatcher{calls: &calls, match: false},
		countingMatcher{calls: &calls, match: true},
	}
	set := &SiteConditionSet{Conditions: []string{"youtube", "openai"}}
	if !g.MatchGeoSiteSet("www.youtube.com", set) {
		t.Fatal("expected match")
	}
	if !g.MatchGeoSiteSet("www.youtube.com", set) {
		t.Fatal("expected cached match")
	}
	if got := calls.Load(); got != 2 {
		t.Fatalf("matcher calls=%d want=2", got)
	}
}
func ipv4(a, b, c, d byte) []byte { func ipv4(a, b, c, d byte) []byte {
return []byte{a, b, c, d} return []byte{a, b, c, d}
} }
+4
View File
@@ -13,6 +13,10 @@ type HostInfo struct {
IPv6 net.IP IPv6 net.IP
} }
type SiteConditionSet struct {
Conditions []string
}
func (h HostInfo) String() string { func (h HostInfo) String() string {
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6) return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
} }
+203 -24
View File
@@ -60,8 +60,8 @@ type compiledExprRule struct {
ModInstance modifier.Instance ModInstance modifier.Instance
Program *vm.Program Program *vm.Program
GeoSiteConditions []string GeoSiteConditions []string
StartTimeSecs int // seconds since midnight, -1 if unset StartTimeSecs int // seconds since midnight, -1 if unset
StopTimeSecs int // seconds since midnight, -1 if unset StopTimeSecs int // seconds since midnight, -1 if unset
Weekdays []time.Weekday Weekdays []time.Weekday
WeekdaysNegated bool WeekdaysNegated bool
} }
@@ -86,6 +86,7 @@ type exprRuleset struct {
Ans []analyzer.Analyzer Ans []analyzer.Analyzer
Logger Logger Logger Logger
GeoMatcher *geo.GeoMatcher GeoMatcher *geo.GeoMatcher
stats *statsCounters
} }
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
@@ -93,9 +94,24 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
} }
func (r *exprRuleset) Match(info StreamInfo) MatchResult { func (r *exprRuleset) Match(info StreamInfo) MatchResult {
start := time.Now()
if r.stats != nil {
r.stats.MatchCalls.Add(1)
defer func() {
r.stats.MatchLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
}()
}
env := envPool.Get().(map[string]any) env := envPool.Get().(map[string]any)
clear(env) clear(env)
populateExprEnv(env, info) macMap, ipMap, portMap := populateExprEnv(env, info)
releaseEnv := func() {
clear(env)
envPool.Put(env)
putSubMap(macMap)
putSubMap(ipMap)
putSubMap(portMap)
}
now := time.Now() now := time.Now()
for _, rule := range r.Rules { for _, rule := range r.Rules {
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
@@ -103,6 +119,9 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
} }
v, err := vm.Run(rule.Program, env) v, err := vm.Run(rule.Program, env)
if err != nil { if err != nil {
if r.stats != nil {
r.stats.MatchErrors.Add(1)
}
r.Logger.MatchError(info, rule.Name, err) r.Logger.MatchError(info, rule.Name, err)
continue continue
} }
@@ -115,7 +134,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
r.Logger.Log(logInfo, rule.Name) r.Logger.Log(logInfo, rule.Name)
} }
if rule.Action != nil { if rule.Action != nil {
envPool.Put(env) releaseEnv()
return MatchResult{ return MatchResult{
Action: *rule.Action, Action: *rule.Action,
ModInstance: rule.ModInstance, ModInstance: rule.ModInstance,
@@ -123,12 +142,26 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
} }
} }
} }
envPool.Put(env) releaseEnv()
return MatchResult{ return MatchResult{
Action: ActionMaybe, Action: ActionMaybe,
} }
} }
func (r *exprRuleset) Stats() Stats {
if r == nil || r.stats == nil {
return Stats{}
}
return Stats{
MatchCalls: r.stats.MatchCalls.Load(),
MatchErrors: r.stats.MatchErrors.Load(),
MatchLatencyNanos: r.stats.MatchLatencyNanos.Load(),
LookupCalls: r.stats.LookupCalls.Load(),
LookupErrors: r.stats.LookupErrors.Load(),
LookupLatencyNanos: r.stats.LookupLatencyNanos.Load(),
}
}
// CompileExprRules compiles a list of expression rules into a ruleset. // CompileExprRules compiles a list of expression rules into a ruleset.
// It returns an error if any of the rules are invalid, or if any of the analyzers // It returns an error if any of the rules are invalid, or if any of the analyzers
// used by the rules are unknown (not provided in the analyzer list). // used by the rules are unknown (not provided in the analyzer list).
@@ -137,7 +170,8 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
fullAnMap := analyzersToMap(ans) fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods) fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer) depAnMap := make(map[string]analyzer.Analyzer)
funcMap, geoMatcher := buildFunctionMap(config) stats := &statsCounters{}
funcMap, geoMatcher := buildFunctionMap(config, stats)
// Compile all rules and build a map of analyzers that are used by the rules. // Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules { for _, rule := range rules {
if rule.Action == "" && !rule.Log { if rule.Action == "" && !rule.Log {
@@ -152,7 +186,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
action = &a action = &a
} }
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
patcher := &idPatcher{FuncMap: funcMap} patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: geoMatcher}
program, err := expr.Compile(rule.Expr, program, err := expr.Compile(rule.Expr,
func(c *conf.Config) { func(c *conf.Config) {
c.Strict = false c.Strict = false
@@ -242,29 +276,47 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
Ans: depAns, Ans: depAns,
Logger: config.Logger, Logger: config.Logger,
GeoMatcher: geoMatcher, GeoMatcher: geoMatcher,
stats: stats,
}, nil }, nil
} }
func populateExprEnv(m map[string]any, info StreamInfo) { func populateExprEnv(m map[string]any, info StreamInfo) (macMap, ipMap, portMap map[string]any) {
macMap = getSubMap()
ipMap = getSubMap()
portMap = getSubMap()
macMap["src"] = info.SrcMAC.String()
macMap["dst"] = info.DstMAC.String()
ipMap["src"] = info.SrcIP.String()
ipMap["dst"] = info.DstIP.String()
portMap["src"] = info.SrcPort
portMap["dst"] = info.DstPort
m["id"] = info.ID m["id"] = info.ID
m["proto"] = info.Protocol.String() m["proto"] = info.Protocol.String()
m["mac"] = map[string]string{ m["mac"] = macMap
"src": info.SrcMAC.String(), m["ip"] = ipMap
"dst": info.DstMAC.String(), m["port"] = portMap
}
m["ip"] = map[string]string{
"src": info.SrcIP.String(),
"dst": info.DstIP.String(),
}
m["port"] = map[string]uint16{
"src": info.SrcPort,
"dst": info.DstPort,
}
for anName, anProps := range info.Props { for anName, anProps := range info.Props {
if len(anProps) != 0 { if len(anProps) != 0 {
m[anName] = anProps m[anName] = anProps
} }
} }
return macMap, ipMap, portMap
}
func getSubMap() map[string]any {
m := subMapPool.Get().(map[string]any)
clear(m)
return m
}
func putSubMap(m map[string]any) {
if m == nil {
return
}
clear(m)
subMapPool.Put(m)
} }
func isBuiltInAnalyzer(name string) bool { func isBuiltInAnalyzer(name string) bool {
@@ -329,11 +381,15 @@ func (v *idVisitor) Visit(node *ast.Node) {
// idPatcher patches the AST during expr compilation, replacing certain values with // idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance. // their internal representations for better runtime performance.
type idPatcher struct { type idPatcher struct {
FuncMap map[string]*Function FuncMap map[string]*Function
Err error GeoMatcher *geo.GeoMatcher
Err error
} }
func (p *idPatcher) Visit(node *ast.Node) { func (p *idPatcher) Visit(node *ast.Node) {
if p.tryPatchGeoSiteORChain(node) {
return
}
switch (*node).(type) { switch (*node).(type) {
case *ast.CallNode: case *ast.CallNode:
callNode := (*node).(*ast.CallNode) callNode := (*node).(*ast.CallNode)
@@ -352,6 +408,108 @@ func (p *idPatcher) Visit(node *ast.Node) {
} }
} }
// tryPatchGeoSiteORChain rewrites an OR chain of geosite(host, "x") calls
// that all share the same host expression into a single geosite_set(host,
// set) call, so the whole set can be matched (and result-cached) at once.
// It reports whether the node was replaced.
func (p *idPatcher) tryPatchGeoSiteORChain(node *ast.Node) bool {
	if p == nil || p.GeoMatcher == nil {
		return false
	}
	terms, ok := collectGeoSiteORChain(*node)
	if !ok || len(terms) < 2 {
		// Nothing to collapse: not a pure geosite OR chain, or a single term.
		return false
	}
	hostExpr := strings.TrimSpace(terms[0].hostExpr)
	if hostExpr == "" {
		return false
	}
	conditions := make([]string, 0, len(terms))
	for _, term := range terms {
		// The rewrite is only valid when every term matches on the same host.
		if strings.TrimSpace(term.hostExpr) != hostExpr {
			return false
		}
		conditions = append(conditions, term.condition)
	}
	normalized := normalizeUniqueLowerStrings(conditions)
	if len(normalized) < 2 {
		return false
	}
	// Re-parse the stringified host expression so it can be grafted into the
	// new call node as a fresh AST subtree.
	hostNode, err := parser.Parse(hostExpr)
	if err != nil || hostNode == nil || hostNode.Node == nil {
		return false
	}
	call := &ast.CallNode{
		Callee: &ast.IdentifierNode{Value: "geosite_set"},
		Arguments: []ast.Node{
			hostNode.Node,
			// The set is embedded as a constant so runtime evaluation passes
			// it straight to GeoMatcher.MatchGeoSiteSet.
			&ast.ConstantNode{Value: &geo.SiteConditionSet{Conditions: normalized}},
		},
	}
	ast.Patch(node, call)
	return true
}
// geositeTerm is one geosite(hostExpr, condition) leaf collected from an OR
// chain; hostExpr is the stringified AST of the host argument.
type geositeTerm struct {
	hostExpr  string
	condition string
}
// collectGeoSiteORChain walks an expression subtree and, when it consists
// solely of geosite(...) / geosite_set(...) calls joined by OR, returns the
// flattened (hostExpr, condition) terms. ok is false for any other shape,
// telling the caller to leave the subtree untouched.
func collectGeoSiteORChain(node ast.Node) ([]geositeTerm, bool) {
	switch n := node.(type) {
	case *ast.BinaryNode:
		if n.Operator != "or" && n.Operator != "||" {
			return nil, false
		}
		// Both OR operands must themselves be pure geosite chains.
		left, ok := collectGeoSiteORChain(n.Left)
		if !ok {
			return nil, false
		}
		right, ok := collectGeoSiteORChain(n.Right)
		if !ok {
			return nil, false
		}
		out := make([]geositeTerm, 0, len(left)+len(right))
		out = append(out, left...)
		out = append(out, right...)
		return out, true
	case *ast.CallNode:
		idNode, ok := n.Callee.(*ast.IdentifierNode)
		if !ok || len(n.Arguments) < 2 {
			return nil, false
		}
		name := strings.ToLower(idNode.Value)
		if name == "geosite" {
			// Only literal-string conditions can be folded into a set.
			condNode, ok := n.Arguments[1].(*ast.StringNode)
			if !ok {
				return nil, false
			}
			return []geositeTerm{{
				hostExpr:  n.Arguments[0].String(),
				condition: condNode.Value,
			}}, true
		}
		if name != "geosite_set" {
			return nil, false
		}
		// An already-rewritten geosite_set node: expand it back to terms so
		// it can merge with further OR'd geosite calls.
		setNode, ok := n.Arguments[1].(*ast.ConstantNode)
		if !ok || setNode.Value == nil {
			return nil, false
		}
		set, ok := setNode.Value.(*geo.SiteConditionSet)
		if !ok || set == nil {
			return nil, false
		}
		if len(set.Conditions) == 0 {
			return nil, false
		}
		out := make([]geositeTerm, 0, len(set.Conditions))
		hostExpr := n.Arguments[0].String()
		for _, condition := range set.Conditions {
			out = append(out, geositeTerm{hostExpr: hostExpr, condition: condition})
		}
		return out, true
	default:
		return nil, false
	}
}
type Function struct { type Function struct {
InitFunc func() error InitFunc func() error
PatchFunc func(args *[]ast.Node) error PatchFunc func(args *[]ast.Node) error
@@ -359,7 +517,7 @@ type Function struct {
Types []reflect.Type Types []reflect.Type
} }
func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatcher) { func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*Function, *geo.GeoMatcher) {
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
return map[string]*Function{ return map[string]*Function{
"geoip": { "geoip": {
@@ -378,6 +536,16 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
}, },
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
}, },
"geosite_set": {
InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil
},
Types: []reflect.Type{
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
},
},
"cidr": { "cidr": {
InitFunc: nil, InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error { PatchFunc: func(args *[]ast.Node) error {
@@ -425,9 +593,20 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
return nil return nil
}, },
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
start := time.Now()
if stats != nil {
stats.LookupCalls.Add(1)
defer func() {
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
}()
}
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel() defer cancel()
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
if err != nil && stats != nil {
stats.LookupErrors.Add(1)
}
return out, err
}, },
Types: []reflect.Type{ Types: []reflect.Type{
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
+25
View File
@@ -2,9 +2,14 @@ package ruleset
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/parser"
) )
func TestExtractGeoSiteConditions(t *testing.T) { func TestExtractGeoSiteConditions(t *testing.T) {
@@ -63,3 +68,23 @@ func TestMatchGeoSiteConditions(t *testing.T) {
t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want) t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want)
} }
} }
// TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet verifies the compile-time
// rewrite: an OR chain of geosite() calls over the same host expression is
// collapsed into a single geosite_set() call with no remaining OR operators.
func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
	tree, err := parser.Parse(`geosite(tls.req.sni, "google") || geosite(tls.req.sni, "youtube") || geosite(tls.req.sni, "openai")`)
	if err != nil {
		t.Fatalf("parse expression: %v", err)
	}
	root := tree.Node
	patcher := &idPatcher{GeoMatcher: geo.NewGeoMatcher("", "")}
	ast.Walk(&root, patcher)
	if patcher.Err != nil {
		t.Fatalf("patch error: %v", patcher.Err)
	}
	got := root.String()
	if !strings.Contains(got, "geosite_set(") {
		t.Fatalf("expected geosite_set rewrite, got %q", got)
	}
	if strings.Contains(got, "||") || strings.Contains(got, " or ") {
		t.Fatalf("expected OR chain to be collapsed, got %q", got)
	}
}
+23
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"strconv" "strconv"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/modifier" "git.difuse.io/Difuse/Mellaris/modifier"
@@ -95,6 +96,28 @@ type Ruleset interface {
Match(StreamInfo) MatchResult Match(StreamInfo) MatchResult
} }
// Stats is a read-only snapshot of ruleset runtime counters.
type Stats struct {
	MatchCalls         uint64 // total Match invocations
	MatchErrors        uint64 // rule evaluations that returned an error
	MatchLatencyNanos  uint64 // cumulative time spent in Match, in nanoseconds
	LookupCalls        uint64 // DNS lookup invocations from rule expressions
	LookupErrors       uint64 // DNS lookups that returned an error
	LookupLatencyNanos uint64 // cumulative DNS lookup time, in nanoseconds
}
// statsCounters holds the live atomic counters behind Stats; it is shared by
// the compiled ruleset and the builtin function closures that record lookups.
type statsCounters struct {
	MatchCalls         atomic.Uint64
	MatchErrors        atomic.Uint64
	MatchLatencyNanos  atomic.Uint64
	LookupCalls        atomic.Uint64
	LookupErrors       atomic.Uint64
	LookupLatencyNanos atomic.Uint64
}
// StatsProvider is implemented by rulesets that can report runtime counters.
type StatsProvider interface {
	Stats() Stats
}
// Logger is the logging interface for the ruleset. // Logger is the logging interface for the ruleset.
type Logger interface { type Logger interface {
Log(info StreamInfo, name string) Log(info StreamInfo, name string)
+22
View File
@@ -0,0 +1,22 @@
package mellaris
import (
"git.difuse.io/Difuse/Mellaris/engine"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Stats aggregates runtime counters from the packet engine and the active
// ruleset into a single snapshot returned by App.Stats.
type Stats struct {
	Engine  engine.Stats
	Ruleset ruleset.Stats
}
// Stats returns a point-in-time snapshot of engine counters, plus ruleset
// counters when the active ruleset implements ruleset.StatsProvider. A nil
// receiver or an app without an engine yields the zero value.
func (a *App) Stats() Stats {
	if a == nil || a.engine == nil {
		return Stats{}
	}
	snapshot := Stats{Engine: a.engine.Stats()}
	provider, ok := a.ruleset.(ruleset.StatsProvider)
	if ok {
		snapshot.Ruleset = provider.Stats()
	}
	return snapshot
}