Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple. Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics. Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
"git.difuse.io/Difuse/Mellaris/engine"
|
"git.difuse.io/Difuse/Mellaris/engine"
|
||||||
@@ -17,6 +18,7 @@ type App struct {
|
|||||||
engine engine.Engine
|
engine engine.Engine
|
||||||
io gfwio.PacketIO
|
io gfwio.PacketIO
|
||||||
rulesetConfig *ruleset.BuiltinConfig
|
rulesetConfig *ruleset.BuiltinConfig
|
||||||
|
ruleset ruleset.Ruleset
|
||||||
analyzers []analyzer.Analyzer
|
analyzers []analyzer.Analyzer
|
||||||
modifiers []modifier.Modifier
|
modifiers []modifier.Modifier
|
||||||
rulesFile string
|
rulesFile string
|
||||||
@@ -42,6 +44,11 @@ func New(cfg Config, opts Options) (*App, error) {
|
|||||||
|
|
||||||
packetIO := cfg.IO.PacketIO
|
packetIO := cfg.IO.PacketIO
|
||||||
ownsIO := false
|
ownsIO := false
|
||||||
|
workerCount := effectiveWorkerCount(cfg.Workers.Count)
|
||||||
|
numQueues := cfg.IO.NumQueues
|
||||||
|
if numQueues <= 0 {
|
||||||
|
numQueues = workerCount
|
||||||
|
}
|
||||||
if packetIO == nil {
|
if packetIO == nil {
|
||||||
packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
|
packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
|
||||||
QueueSize: cfg.IO.QueueSize,
|
QueueSize: cfg.IO.QueueSize,
|
||||||
@@ -49,7 +56,7 @@ func New(cfg Config, opts Options) (*App, error) {
|
|||||||
WriteBuffer: cfg.IO.WriteBuffer,
|
WriteBuffer: cfg.IO.WriteBuffer,
|
||||||
Local: cfg.IO.Local,
|
Local: cfg.IO.Local,
|
||||||
RST: cfg.IO.RST,
|
RST: cfg.IO.RST,
|
||||||
NumQueues: cfg.IO.NumQueues,
|
NumQueues: numQueues,
|
||||||
MaxPacketLen: cfg.IO.MaxPacketLen,
|
MaxPacketLen: cfg.IO.MaxPacketLen,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,11 +86,13 @@ func New(cfg Config, opts Options) (*App, error) {
|
|||||||
Logger: engineLogger,
|
Logger: engineLogger,
|
||||||
IO: packetIO,
|
IO: packetIO,
|
||||||
Ruleset: rs,
|
Ruleset: rs,
|
||||||
Workers: cfg.Workers.Count,
|
Workers: workerCount,
|
||||||
WorkerQueueSize: cfg.Workers.QueueSize,
|
WorkerQueueSize: cfg.Workers.QueueSize,
|
||||||
WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal,
|
WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal,
|
||||||
WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn,
|
WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn,
|
||||||
WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams,
|
WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams,
|
||||||
|
OverflowPolicy: cfg.Workers.OverflowPolicy,
|
||||||
|
AnalyzerSelectionMode: cfg.Workers.AnalyzerSelectionMode,
|
||||||
}
|
}
|
||||||
eng, err := engine.NewEngine(engCfg)
|
eng, err := engine.NewEngine(engCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,6 +104,7 @@ func New(cfg Config, opts Options) (*App, error) {
|
|||||||
engine: eng,
|
engine: eng,
|
||||||
io: packetIO,
|
io: packetIO,
|
||||||
rulesetConfig: rsConfig,
|
rulesetConfig: rsConfig,
|
||||||
|
ruleset: rs,
|
||||||
analyzers: analyzers,
|
analyzers: analyzers,
|
||||||
modifiers: modifiers,
|
modifiers: modifiers,
|
||||||
rulesFile: rulesFile,
|
rulesFile: rulesFile,
|
||||||
@@ -140,6 +150,17 @@ func (a *App) Engine() engine.Engine {
|
|||||||
return a.engine
|
return a.engine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveWorkerCount(configured int) int {
|
||||||
|
if configured > 0 {
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
n := runtime.GOMAXPROCS(0)
|
||||||
|
if n <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) {
|
func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) {
|
||||||
if opts.RulesFile != "" && len(opts.Rules) > 0 {
|
if opts.RulesFile != "" && len(opts.Rules) > 0 {
|
||||||
return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")}
|
return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")}
|
||||||
|
|||||||
@@ -32,11 +32,13 @@ type IOConfig struct {
|
|||||||
|
|
||||||
// WorkersConfig configures engine worker behavior.
|
// WorkersConfig configures engine worker behavior.
|
||||||
type WorkersConfig struct {
|
type WorkersConfig struct {
|
||||||
Count int `mapstructure:"count" yaml:"count"`
|
Count int `mapstructure:"count" yaml:"count"`
|
||||||
QueueSize int `mapstructure:"queueSize" yaml:"queueSize"`
|
QueueSize int `mapstructure:"queueSize" yaml:"queueSize"`
|
||||||
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"`
|
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"`
|
||||||
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"`
|
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"`
|
||||||
UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"`
|
UDPMaxStreams int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"`
|
||||||
|
OverflowPolicy engine.OverflowPolicy `mapstructure:"overflowPolicy" yaml:"overflowPolicy"`
|
||||||
|
AnalyzerSelectionMode engine.AnalyzerSelectionMode `mapstructure:"analyzerSelectionMode" yaml:"analyzerSelectionMode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RulesetConfig configures built-in rule helpers.
|
// RulesetConfig configures built-in rule helpers.
|
||||||
|
|||||||
@@ -0,0 +1,235 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type analyzerSelector struct {
|
||||||
|
mode AnalyzerSelectionMode
|
||||||
|
stats *statsCounters
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnalyzerSelector(mode AnalyzerSelectionMode, stats *statsCounters) *analyzerSelector {
|
||||||
|
if mode == "" {
|
||||||
|
mode = AnalyzerSelectionModeSignature
|
||||||
|
}
|
||||||
|
return &analyzerSelector{mode: mode, stats: stats}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *analyzerSelector) SelectTCP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
|
||||||
|
if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
allowed := tcpAllowedAnalyzers(payload)
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
out := make([]analyzer.Analyzer, 0, len(ans))
|
||||||
|
for _, a := range ans {
|
||||||
|
name := strings.ToLower(a.Name())
|
||||||
|
if _, known := knownTCPAnalyzers[name]; !known {
|
||||||
|
out = append(out, a)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if allowed[name] {
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.recordSelection(len(ans), len(out))
|
||||||
|
if len(out) == 0 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *analyzerSelector) SelectUDP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
|
||||||
|
if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
allowed := udpAllowedAnalyzers(payload)
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
out := make([]analyzer.Analyzer, 0, len(ans))
|
||||||
|
for _, a := range ans {
|
||||||
|
name := strings.ToLower(a.Name())
|
||||||
|
if _, known := knownUDPAnalyzers[name]; !known {
|
||||||
|
out = append(out, a)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if allowed[name] {
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.recordSelection(len(ans), len(out))
|
||||||
|
if len(out) == 0 {
|
||||||
|
return ans
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *analyzerSelector) recordSelection(total, selected int) {
|
||||||
|
if s == nil || s.stats == nil || total <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stats.AnalyzerSelectionsTotal.Add(1)
|
||||||
|
if selected < total {
|
||||||
|
s.stats.AnalyzerSelectionsPruned.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
knownTCPAnalyzers = map[string]struct{}{
|
||||||
|
"fet": {},
|
||||||
|
"http": {},
|
||||||
|
"socks": {},
|
||||||
|
"ssh": {},
|
||||||
|
"tls": {},
|
||||||
|
"trojan": {},
|
||||||
|
"dns": {},
|
||||||
|
"openvpn": {},
|
||||||
|
}
|
||||||
|
knownUDPAnalyzers = map[string]struct{}{
|
||||||
|
"dns": {},
|
||||||
|
"openvpn": {},
|
||||||
|
"quic": {},
|
||||||
|
"wireguard": {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func tcpAllowedAnalyzers(payload []byte) map[string]bool {
|
||||||
|
allowed := make(map[string]bool, 4)
|
||||||
|
if looksLikeTLS(payload) {
|
||||||
|
allowed["tls"] = true
|
||||||
|
allowed["trojan"] = true
|
||||||
|
allowed["fet"] = true
|
||||||
|
}
|
||||||
|
if looksLikeHTTP(payload) {
|
||||||
|
allowed["http"] = true
|
||||||
|
allowed["fet"] = true
|
||||||
|
}
|
||||||
|
if looksLikeSSH(payload) {
|
||||||
|
allowed["ssh"] = true
|
||||||
|
allowed["fet"] = true
|
||||||
|
}
|
||||||
|
if looksLikeSOCKS(payload) {
|
||||||
|
allowed["socks"] = true
|
||||||
|
allowed["fet"] = true
|
||||||
|
}
|
||||||
|
if looksLikeDNSTCP(payload) {
|
||||||
|
allowed["dns"] = true
|
||||||
|
allowed["fet"] = true
|
||||||
|
}
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func udpAllowedAnalyzers(payload []byte) map[string]bool {
|
||||||
|
allowed := make(map[string]bool, 4)
|
||||||
|
if looksLikeWireGuard(payload) {
|
||||||
|
allowed["wireguard"] = true
|
||||||
|
}
|
||||||
|
if looksLikeOpenVPN(payload) {
|
||||||
|
allowed["openvpn"] = true
|
||||||
|
}
|
||||||
|
if looksLikeQUIC(payload) {
|
||||||
|
allowed["quic"] = true
|
||||||
|
}
|
||||||
|
if looksLikeDNSUDP(payload) {
|
||||||
|
allowed["dns"] = true
|
||||||
|
}
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeTLS(payload []byte) bool {
|
||||||
|
if len(payload) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return (payload[0] == 0x16 || payload[0] == 0x17) && payload[1] == 0x03 && payload[2] <= 0x09
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeHTTP(payload []byte) bool {
|
||||||
|
if len(payload) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
head := strings.ToUpper(string(payload[:3]))
|
||||||
|
switch head {
|
||||||
|
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeSSH(payload []byte) bool {
|
||||||
|
return len(payload) >= 4 && bytes.HasPrefix(payload, []byte("SSH-"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeSOCKS(payload []byte) bool {
|
||||||
|
if len(payload) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return payload[0] == 0x04 || payload[0] == 0x05
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeDNSTCP(payload []byte) bool {
|
||||||
|
if len(payload) < 14 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msgLen := int(payload[0])<<8 | int(payload[1])
|
||||||
|
if msgLen <= 0 || msgLen+2 > len(payload) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
qd := int(payload[6])<<8 | int(payload[7])
|
||||||
|
an := int(payload[8])<<8 | int(payload[9])
|
||||||
|
return qd+an > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeDNSUDP(payload []byte) bool {
|
||||||
|
if len(payload) < 12 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
qd := int(payload[4])<<8 | int(payload[5])
|
||||||
|
an := int(payload[6])<<8 | int(payload[7])
|
||||||
|
ns := int(payload[8])<<8 | int(payload[9])
|
||||||
|
ar := int(payload[10])<<8 | int(payload[11])
|
||||||
|
return qd+an+ns+ar > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeQUIC(payload []byte) bool {
|
||||||
|
if len(payload) < 6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Long header with non-zero version.
|
||||||
|
if payload[0]&0x80 == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
version := uint32(payload[1])<<24 | uint32(payload[2])<<16 | uint32(payload[3])<<8 | uint32(payload[4])
|
||||||
|
return version != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeOpenVPN(payload []byte) bool {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
opcode := payload[0] >> 3
|
||||||
|
return opcode >= 1 && opcode <= 11
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeWireGuard(payload []byte) bool {
|
||||||
|
if len(payload) < 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if payload[0] < 1 || payload[0] > 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return payload[1] == 0 && payload[2] == 0 && payload[3] == 0
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type namedAnalyzer struct{ name string }
|
||||||
|
|
||||||
|
func (a namedAnalyzer) Name() string { return a.name }
|
||||||
|
func (a namedAnalyzer) Limit() int { return 0 }
|
||||||
|
|
||||||
|
func TestSignatureSelectorTCPPrunesByPayloadNotPort(t *testing.T) {
|
||||||
|
sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
|
||||||
|
all := []analyzer.Analyzer{
|
||||||
|
namedAnalyzer{"http"},
|
||||||
|
namedAnalyzer{"tls"},
|
||||||
|
namedAnalyzer{"trojan"},
|
||||||
|
namedAnalyzer{"ssh"},
|
||||||
|
namedAnalyzer{"socks"},
|
||||||
|
namedAnalyzer{"fet"},
|
||||||
|
}
|
||||||
|
// TLS record-like prefix, regardless of destination port.
|
||||||
|
payload := []byte{0x16, 0x03, 0x03, 0x00, 0x10}
|
||||||
|
selected := sel.SelectTCP(all, payload)
|
||||||
|
|
||||||
|
got := make(map[string]bool)
|
||||||
|
for _, a := range selected {
|
||||||
|
got[a.Name()] = true
|
||||||
|
}
|
||||||
|
for _, keep := range []string{"tls", "trojan", "fet"} {
|
||||||
|
if !got[keep] {
|
||||||
|
t.Fatalf("expected analyzer %q to be selected", keep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, drop := range []string{"http", "ssh", "socks"} {
|
||||||
|
if got[drop] {
|
||||||
|
t.Fatalf("expected analyzer %q to be pruned", drop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignatureSelectorConservativeFallback(t *testing.T) {
|
||||||
|
sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
|
||||||
|
all := []analyzer.Analyzer{
|
||||||
|
namedAnalyzer{"http"},
|
||||||
|
namedAnalyzer{"tls"},
|
||||||
|
namedAnalyzer{"custom"},
|
||||||
|
}
|
||||||
|
payload := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
|
selected := sel.SelectTCP(all, payload)
|
||||||
|
if len(selected) != len(all) {
|
||||||
|
t.Fatalf("expected conservative fallback to keep all analyzers, got=%d want=%d", len(selected), len(all))
|
||||||
|
}
|
||||||
|
}
|
||||||
+68
-25
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -20,18 +21,34 @@ type engine struct {
|
|||||||
logger Logger
|
logger Logger
|
||||||
io io.PacketIO
|
io io.PacketIO
|
||||||
workers []*worker
|
workers []*worker
|
||||||
verdicts sync.Map // streamID(uint32) → verdictEntry
|
stats *statsCounters
|
||||||
|
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
||||||
verdictsGen atomic.Int64 // incremented on ruleset update
|
verdictsGen atomic.Int64 // incremented on ruleset update
|
||||||
|
|
||||||
overflowCh chan *workerPacket
|
overflowPolicy OverflowPolicy
|
||||||
overflowOnce sync.Once
|
resultCh chan workerResult
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEngine(config Config) (Engine, error) {
|
func NewEngine(config Config) (Engine, error) {
|
||||||
workerCount := config.Workers
|
workerCount := config.Workers
|
||||||
if workerCount <= 0 {
|
if workerCount <= 0 {
|
||||||
workerCount = 1
|
workerCount = runtime.GOMAXPROCS(0)
|
||||||
|
if workerCount <= 0 {
|
||||||
|
workerCount = 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
overflowPolicy := config.OverflowPolicy
|
||||||
|
if overflowPolicy == "" {
|
||||||
|
overflowPolicy = OverflowPolicyAccept
|
||||||
|
}
|
||||||
|
selectionMode := config.AnalyzerSelectionMode
|
||||||
|
if selectionMode == "" {
|
||||||
|
selectionMode = AnalyzerSelectionModeSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := &statsCounters{}
|
||||||
|
resultCh := make(chan workerResult, workerCount*256)
|
||||||
|
|
||||||
macResolver := newSourceMACResolver()
|
macResolver := newSourceMACResolver()
|
||||||
var err error
|
var err error
|
||||||
workers := make([]*worker, workerCount)
|
workers := make([]*worker, workerCount)
|
||||||
@@ -45,16 +62,21 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||||||
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||||||
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||||||
|
AnalyzerSelectionMode: selectionMode,
|
||||||
|
ResultChan: resultCh,
|
||||||
|
Stats: stats,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e := &engine{
|
e := &engine{
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
io: config.IO,
|
io: config.IO,
|
||||||
workers: workers,
|
workers: workers,
|
||||||
overflowCh: make(chan *workerPacket, 1024),
|
stats: stats,
|
||||||
|
overflowPolicy: overflowPolicy,
|
||||||
|
resultCh: resultCh,
|
||||||
}
|
}
|
||||||
return e, nil
|
return e, nil
|
||||||
}
|
}
|
||||||
@@ -74,13 +96,10 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||||
defer ioCancel()
|
defer ioCancel()
|
||||||
|
|
||||||
e.overflowOnce.Do(func() {
|
|
||||||
go e.drainOverflow(ioCtx)
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
|
go e.drainResults(ioCtx)
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||||
@@ -121,24 +140,35 @@ func (e *engine) dispatch(p io.Packet) bool {
|
|||||||
gen := e.verdictsGen.Load()
|
gen := e.verdictsGen.Load()
|
||||||
index := streamID % uint32(len(e.workers))
|
index := streamID % uint32(len(e.workers))
|
||||||
wp := &workerPacket{
|
wp := &workerPacket{
|
||||||
StreamID: streamID,
|
Packet: p,
|
||||||
Data: data,
|
StreamID: streamID,
|
||||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
Data: data,
|
||||||
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
Gen: gen,
|
||||||
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
|
||||||
}
|
|
||||||
return e.io.SetVerdict(p, v, b)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
if !e.workers[index].Feed(wp) {
|
if !e.workers[index].Feed(wp) {
|
||||||
select {
|
e.stats.OverflowEvents.Add(1)
|
||||||
case e.overflowCh <- wp:
|
switch e.overflowPolicy {
|
||||||
|
case OverflowPolicyDrop:
|
||||||
|
e.stats.OverflowDrops.Add(1)
|
||||||
|
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
|
||||||
|
case OverflowPolicyBackpressure:
|
||||||
|
e.stats.OverflowBackpressureEvents.Add(1)
|
||||||
|
e.workers[index].FeedBlocking(wp)
|
||||||
default:
|
default:
|
||||||
|
e.stats.OverflowAccepts.Add(1)
|
||||||
|
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *engine) applyWorkerResult(r workerResult) {
|
||||||
|
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
|
||||||
|
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
|
||||||
|
}
|
||||||
|
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
||||||
|
}
|
||||||
|
|
||||||
func validPacket(data []byte) bool {
|
func validPacket(data []byte) bool {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -156,13 +186,26 @@ func validPacket(data []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *engine) drainOverflow(ctx context.Context) {
|
func (e *engine) drainResults(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case wp := <-e.overflowCh:
|
case r := <-e.resultCh:
|
||||||
_ = wp.SetVerdict(io.VerdictAccept, nil)
|
e.applyWorkerResult(r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *engine) Stats() Stats {
|
||||||
|
return Stats{
|
||||||
|
OverflowEvents: e.stats.OverflowEvents.Load(),
|
||||||
|
OverflowAccepts: e.stats.OverflowAccepts.Load(),
|
||||||
|
OverflowDrops: e.stats.OverflowDrops.Load(),
|
||||||
|
OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
|
||||||
|
AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(),
|
||||||
|
AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(),
|
||||||
|
UDPTupleLookups: e.stats.UDPTupleLookups.Load(),
|
||||||
|
UDPTupleHits: e.stats.UDPTupleHits.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
@@ -13,6 +14,49 @@ type Engine interface {
|
|||||||
UpdateRuleset(ruleset.Ruleset) error
|
UpdateRuleset(ruleset.Ruleset) error
|
||||||
// Run runs the engine, until an error occurs or the context is cancelled.
|
// Run runs the engine, until an error occurs or the context is cancelled.
|
||||||
Run(context.Context) error
|
Run(context.Context) error
|
||||||
|
// Stats returns a consistent snapshot of runtime counters.
|
||||||
|
Stats() Stats
|
||||||
|
}
|
||||||
|
|
||||||
|
type OverflowPolicy string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OverflowPolicyAccept OverflowPolicy = "accept"
|
||||||
|
OverflowPolicyDrop OverflowPolicy = "drop"
|
||||||
|
OverflowPolicyBackpressure OverflowPolicy = "backpressure"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AnalyzerSelectionMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AnalyzerSelectionModeAlways AnalyzerSelectionMode = "always"
|
||||||
|
AnalyzerSelectionModeSignature AnalyzerSelectionMode = "signature"
|
||||||
|
)
|
||||||
|
|
||||||
|
type statsCounters struct {
|
||||||
|
OverflowEvents atomic.Uint64
|
||||||
|
OverflowAccepts atomic.Uint64
|
||||||
|
OverflowDrops atomic.Uint64
|
||||||
|
OverflowBackpressureEvents atomic.Uint64
|
||||||
|
|
||||||
|
AnalyzerSelectionsTotal atomic.Uint64
|
||||||
|
AnalyzerSelectionsPruned atomic.Uint64
|
||||||
|
|
||||||
|
UDPTupleLookups atomic.Uint64
|
||||||
|
UDPTupleHits atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
OverflowEvents uint64
|
||||||
|
OverflowAccepts uint64
|
||||||
|
OverflowDrops uint64
|
||||||
|
OverflowBackpressureEvents uint64
|
||||||
|
|
||||||
|
AnalyzerSelectionsTotal uint64
|
||||||
|
AnalyzerSelectionsPruned uint64
|
||||||
|
|
||||||
|
UDPTupleLookups uint64
|
||||||
|
UDPTupleHits uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config is the configuration for the engine.
|
// Config is the configuration for the engine.
|
||||||
@@ -26,6 +70,8 @@ type Config struct {
|
|||||||
WorkerTCPMaxBufferedPagesTotal int
|
WorkerTCPMaxBufferedPagesTotal int
|
||||||
WorkerTCPMaxBufferedPagesPerConn int
|
WorkerTCPMaxBufferedPagesPerConn int
|
||||||
WorkerUDPMaxStreams int
|
WorkerUDPMaxStreams int
|
||||||
|
OverflowPolicy OverflowPolicy
|
||||||
|
AnalyzerSelectionMode AnalyzerSelectionMode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger is the combined logging interface for the engine, workers and analyzers.
|
// Logger is the combined logging interface for the engine, workers and analyzers.
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
package engine
|
package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package engine
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type sourceMACResolver struct{}
|
||||||
|
|
||||||
|
func newSourceMACResolver() *sourceMACResolver {
|
||||||
|
return &sourceMACResolver{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
||||||
|
_ = ip
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -142,7 +142,7 @@ func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create node: %v", err)
|
t.Fatalf("create node: %v", err)
|
||||||
}
|
}
|
||||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
|
||||||
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||||
|
|
||||||
l3 := L3Info{
|
l3 := L3Info{
|
||||||
@@ -180,7 +180,7 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create node: %v", err)
|
t.Fatalf("create node: %v", err)
|
||||||
}
|
}
|
||||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
|
||||||
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||||
|
|
||||||
l3 := L3Info{
|
l3 := L3Info{
|
||||||
|
|||||||
+8
-4
@@ -163,15 +163,17 @@ type tcpFlowManager struct {
|
|||||||
rulesetSource func() (ruleset.Ruleset, uint64)
|
rulesetSource func() (ruleset.Ruleset, uint64)
|
||||||
workerID int
|
workerID int
|
||||||
macResolver *sourceMACResolver
|
macResolver *sourceMACResolver
|
||||||
|
selector *analyzerSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
|
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
|
||||||
return &tcpFlowManager{
|
return &tcpFlowManager{
|
||||||
flows: make(map[uint32]*tcpFlow),
|
flows: make(map[uint32]*tcpFlow),
|
||||||
sfNode: node,
|
sfNode: node,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
workerID: workerID,
|
workerID: workerID,
|
||||||
macResolver: macResolver,
|
macResolver: macResolver,
|
||||||
|
selector: selector,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +181,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
flow, ok := m.flows[streamID]
|
flow, ok := m.flows[streamID]
|
||||||
if !ok {
|
if !ok {
|
||||||
flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
|
flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
|
||||||
m.flows[streamID] = flow
|
m.flows[streamID] = flow
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
@@ -195,7 +197,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
|
|||||||
return verdict
|
return verdict
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
|
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
|
||||||
id := m.sfNode.Generate()
|
id := m.sfNode.Generate()
|
||||||
ipSrc := net.IP(l3.SrcIP[:])
|
ipSrc := net.IP(l3.SrcIP[:])
|
||||||
ipDst := net.IP(l3.DstIP[:])
|
ipDst := net.IP(l3.DstIP[:])
|
||||||
@@ -217,7 +219,9 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, src
|
|||||||
rs, version := m.rulesetSource()
|
rs, version := m.rulesetSource()
|
||||||
var ans []analyzer.TCPAnalyzer
|
var ans []analyzer.TCPAnalyzer
|
||||||
if rs != nil {
|
if rs != nil {
|
||||||
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
|
baseAns := rs.Analyzers(info)
|
||||||
|
baseAns = m.selector.SelectTCP(baseAns, payload)
|
||||||
|
ans = analyzersToTCPAnalyzers(baseAns)
|
||||||
}
|
}
|
||||||
entries := make([]*tcpFlowEntry, 0, len(ans))
|
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||||
for _, a := range ans {
|
for _, a := range ans {
|
||||||
|
|||||||
+109
-21
@@ -1,6 +1,7 @@
|
|||||||
package engine
|
package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -40,6 +41,8 @@ type udpStreamFactory struct {
|
|||||||
WorkerID int
|
WorkerID int
|
||||||
Logger Logger
|
Logger Logger
|
||||||
Node *snowflake.Node
|
Node *snowflake.Node
|
||||||
|
Selector *analyzerSelector
|
||||||
|
Stats *statsCounters
|
||||||
|
|
||||||
RulesetMutex sync.RWMutex
|
RulesetMutex sync.RWMutex
|
||||||
Ruleset ruleset.Ruleset
|
Ruleset ruleset.Ruleset
|
||||||
@@ -64,7 +67,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
|||||||
rs, version := f.currentRuleset()
|
rs, version := f.currentRuleset()
|
||||||
var ans []analyzer.UDPAnalyzer
|
var ans []analyzer.UDPAnalyzer
|
||||||
if rs != nil {
|
if rs != nil {
|
||||||
ans = analyzersToUDPAnalyzers(rs.Analyzers(info))
|
baseAns := rs.Analyzers(info)
|
||||||
|
if f.Selector != nil {
|
||||||
|
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
|
||||||
|
}
|
||||||
|
ans = analyzersToUDPAnalyzers(baseAns)
|
||||||
}
|
}
|
||||||
// Create entries for each analyzer
|
// Create entries for each analyzer
|
||||||
entries := make([]*udpStreamEntry, 0, len(ans))
|
entries := make([]*udpStreamEntry, 0, len(ans))
|
||||||
@@ -110,8 +117,11 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type udpStreamManager struct {
|
type udpStreamManager struct {
|
||||||
factory *udpStreamFactory
|
factory *udpStreamFactory
|
||||||
streams *lru.Cache[uint32, *udpStreamValue]
|
streams *lru.Cache[uint32, *udpStreamValue]
|
||||||
|
tupleIndex map[udpTupleKey]uint32
|
||||||
|
streamTuples map[uint32]udpTupleKey
|
||||||
|
stats *statsCounters
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpStreamValue struct {
|
type udpStreamValue struct {
|
||||||
@@ -120,36 +130,71 @@ type udpStreamValue struct {
|
|||||||
UDPFlow gopacket.Flow
|
UDPFlow gopacket.Flow
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type udpTupleKey struct {
|
||||||
|
AIP [16]byte
|
||||||
|
BIP [16]byte
|
||||||
|
ALen uint8
|
||||||
|
BLen uint8
|
||||||
|
APort uint16
|
||||||
|
BPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
||||||
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
||||||
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
||||||
return fwd || rev, rev
|
return fwd || rev, rev
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
|
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||||
ss, err := lru.New[uint32, *udpStreamValue](maxStreams)
|
m := &udpStreamManager{
|
||||||
|
factory: factory,
|
||||||
|
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
||||||
|
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
||||||
|
stats: stats,
|
||||||
|
}
|
||||||
|
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
||||||
|
m.removeTupleMappingLocked(k)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &udpStreamManager{
|
m.streams = ss
|
||||||
factory: factory,
|
return m, nil
|
||||||
streams: ss,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
||||||
rev := false
|
rev := false
|
||||||
value, ok := m.streams.Get(streamID)
|
value, ok := m.streams.Get(streamID)
|
||||||
|
tuple := canonicalUDPTupleKey(ipFlow, udp)
|
||||||
if !ok {
|
if !ok {
|
||||||
// Fallback: conntrack IDs can change during early flow lifetime on some systems.
|
if m.stats != nil {
|
||||||
// Try to find an existing stream by 5-tuple before creating a new stream.
|
m.stats.UDPTupleLookups.Add(1)
|
||||||
matchedKey, matchedValue, matchedRev, found := m.findByFlow(ipFlow, udp.TransportFlow())
|
}
|
||||||
|
// Conntrack IDs can change during early flow lifetime on some systems.
|
||||||
|
// Rebind by canonical 5-tuple in O(1).
|
||||||
|
matchedKey, found := m.tupleIndex[tuple]
|
||||||
|
var matchedValue *udpStreamValue
|
||||||
|
var matchedRev bool
|
||||||
if found {
|
if found {
|
||||||
|
if m.stats != nil {
|
||||||
|
m.stats.UDPTupleHits.Add(1)
|
||||||
|
}
|
||||||
|
var hasValue bool
|
||||||
|
matchedValue, hasValue = m.streams.Get(matchedKey)
|
||||||
|
if !hasValue || matchedValue == nil {
|
||||||
|
delete(m.tupleIndex, tuple)
|
||||||
|
delete(m.streamTuples, matchedKey)
|
||||||
|
found = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
|
||||||
value = matchedValue
|
value = matchedValue
|
||||||
rev = matchedRev
|
rev = matchedRev
|
||||||
if matchedKey != streamID {
|
if matchedKey != streamID {
|
||||||
m.streams.Remove(matchedKey)
|
m.streams.Remove(matchedKey)
|
||||||
m.streams.Add(streamID, matchedValue)
|
m.streams.Add(streamID, matchedValue)
|
||||||
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// New stream
|
// New stream
|
||||||
@@ -159,6 +204,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
|||||||
UDPFlow: udp.TransportFlow(),
|
UDPFlow: udp.TransportFlow(),
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.streams.Add(streamID, value)
|
||||||
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Stream ID exists, but is it really the same stream?
|
// Stream ID exists, but is it really the same stream?
|
||||||
@@ -172,6 +218,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
|||||||
UDPFlow: udp.TransportFlow(),
|
UDPFlow: udp.TransportFlow(),
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.streams.Add(streamID, value)
|
||||||
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if value.Stream.Accept(udp, rev, uc) {
|
if value.Stream.Accept(udp, rev, uc) {
|
||||||
@@ -179,17 +226,58 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32, value *udpStreamValue, rev bool, found bool) {
|
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
||||||
for _, k := range m.streams.Keys() {
|
m.removeTupleMappingLocked(streamID)
|
||||||
v, ok := m.streams.Peek(k)
|
m.tupleIndex[key] = streamID
|
||||||
if !ok || v == nil {
|
m.streamTuples[streamID] = key
|
||||||
continue
|
}
|
||||||
}
|
|
||||||
if ok2, rev2 := v.Match(ipFlow, udpFlow); ok2 {
|
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
|
||||||
return k, v, rev2, true
|
if key, ok := m.streamTuples[streamID]; ok {
|
||||||
|
delete(m.streamTuples, streamID)
|
||||||
|
current, exists := m.tupleIndex[key]
|
||||||
|
if exists && current == streamID {
|
||||||
|
delete(m.tupleIndex, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, nil, false, false
|
}
|
||||||
|
|
||||||
|
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
|
||||||
|
srcIP := ipFlow.Src().Raw()
|
||||||
|
dstIP := ipFlow.Dst().Raw()
|
||||||
|
srcPort := uint16(udp.SrcPort)
|
||||||
|
dstPort := uint16(udp.DstPort)
|
||||||
|
|
||||||
|
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
|
||||||
|
srcIP, dstIP = dstIP, srcIP
|
||||||
|
srcPort, dstPort = dstPort, srcPort
|
||||||
|
}
|
||||||
|
|
||||||
|
var key udpTupleKey
|
||||||
|
key.ALen = uint8(copy(key.AIP[:], srcIP))
|
||||||
|
key.BLen = uint8(copy(key.BIP[:], dstIP))
|
||||||
|
key.APort = srcPort
|
||||||
|
key.BPort = dstPort
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
|
||||||
|
if len(aIP) != len(bIP) {
|
||||||
|
if len(aIP) < len(bIP) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if c := bytes.Compare(aIP, bIP); c != 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if aPort < bPort {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if aPort > bPort {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpStream struct {
|
type udpStream struct {
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
|
"github.com/bwmarrin/snowflake"
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
)
|
||||||
|
|
||||||
|
type legacyUDPStreamValue struct {
|
||||||
|
IPFlow gopacket.Flow
|
||||||
|
UDPFlow gopacket.Flow
|
||||||
|
}
|
||||||
|
|
||||||
|
type emptyRuleset struct{}
|
||||||
|
|
||||||
|
func (emptyRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil }
|
||||||
|
func (emptyRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||||
|
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkUDPManager(b *testing.B, churn bool) {
|
||||||
|
node, err := snowflake.NewNode(0)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("create node: %v", err)
|
||||||
|
}
|
||||||
|
factory := &udpStreamFactory{WorkerID: 0, Logger: noopTestLogger{}, Node: node, Ruleset: emptyRuleset{}}
|
||||||
|
mgr, err := newUDPStreamManager(factory, 200000, &statsCounters{})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("new manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const flowCount = 20000
|
||||||
|
flows := make([]gopacket.Flow, flowCount)
|
||||||
|
udps := make([]*layers.UDP, flowCount)
|
||||||
|
for i := 0; i < flowCount; i++ {
|
||||||
|
a := byte(i >> 8)
|
||||||
|
c := byte(i)
|
||||||
|
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||||
|
udps[i] = &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||||
|
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||||
|
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
idx := i % flowCount
|
||||||
|
streamID := uint32(idx + 1)
|
||||||
|
if churn {
|
||||||
|
streamID = uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
|
||||||
|
}
|
||||||
|
ctx.Verdict = udpVerdictAccept
|
||||||
|
ctx.Packet = nil
|
||||||
|
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUDPManagerMatchStableStreamID(b *testing.B) {
|
||||||
|
benchmarkUDPManager(b, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
|
||||||
|
benchmarkUDPManager(b, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
||||||
|
const flowCount = 5000
|
||||||
|
flows := make([]gopacket.Flow, flowCount)
|
||||||
|
udps := make([]*layers.UDP, flowCount)
|
||||||
|
for i := 0; i < flowCount; i++ {
|
||||||
|
a := byte(i >> 8)
|
||||||
|
c := byte(i)
|
||||||
|
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||||
|
udps[i] = &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||||
|
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||||
|
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
|
||||||
|
keys := make([]uint32, 0, flowCount)
|
||||||
|
for i := 0; i < flowCount; i++ {
|
||||||
|
streamID := uint32(i + 1)
|
||||||
|
streams[streamID] = &legacyUDPStreamValue{
|
||||||
|
IPFlow: flows[i],
|
||||||
|
UDPFlow: udps[i].TransportFlow(),
|
||||||
|
}
|
||||||
|
keys = append(keys, streamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
idx := i % flowCount
|
||||||
|
streamID := uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
|
||||||
|
if _, ok := streams[streamID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ipFlow := flows[idx]
|
||||||
|
udpFlow := udps[idx].TransportFlow()
|
||||||
|
for _, k := range keys {
|
||||||
|
v, ok := streams[k]
|
||||||
|
if !ok || v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
|
||||||
|
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
|
||||||
|
delete(streams, k)
|
||||||
|
streams[streamID] = v
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
|
"github.com/bwmarrin/snowflake"
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
)
|
||||||
|
|
||||||
|
type countingRuleset struct {
|
||||||
|
ans []analyzer.Analyzer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r countingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return r.ans }
|
||||||
|
func (r countingRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||||
|
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingUDPAnalyzer struct{ newCalls *atomic.Uint64 }
|
||||||
|
|
||||||
|
func (a countingUDPAnalyzer) Name() string { return "countudp" }
|
||||||
|
func (a countingUDPAnalyzer) Limit() int { return 0 }
|
||||||
|
func (a countingUDPAnalyzer) NewUDP(analyzer.UDPInfo, analyzer.Logger) analyzer.UDPStream {
|
||||||
|
a.newCalls.Add(1)
|
||||||
|
return countingUDPStream{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingUDPStream struct{}
|
||||||
|
|
||||||
|
func (countingUDPStream) Feed(bool, []byte) (*analyzer.PropUpdate, bool) { return nil, false }
|
||||||
|
func (countingUDPStream) Close(bool) *analyzer.PropUpdate { return nil }
|
||||||
|
|
||||||
|
func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
|
||||||
|
node, err := snowflake.NewNode(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create node: %v", err)
|
||||||
|
}
|
||||||
|
var newCalls atomic.Uint64
|
||||||
|
rs := countingRuleset{ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &newCalls}}}
|
||||||
|
factory := &udpStreamFactory{
|
||||||
|
WorkerID: 0,
|
||||||
|
Logger: noopTestLogger{},
|
||||||
|
Node: node,
|
||||||
|
Ruleset: rs,
|
||||||
|
}
|
||||||
|
mgr, err := newUDPStreamManager(factory, 64, &statsCounters{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
||||||
|
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
|
||||||
|
|
||||||
|
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
|
mgr.MatchWithContext(100, ipFlow, udp, ctx1)
|
||||||
|
if got := newCalls.Load(); got != 1 {
|
||||||
|
t.Fatalf("new stream calls=%d want=1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
|
mgr.MatchWithContext(200, ipFlow, udp, ctx2)
|
||||||
|
if got := newCalls.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
+38
-14
@@ -12,24 +12,32 @@ import (
|
|||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Engine = (*engine)(nil)
|
|
||||||
|
|
||||||
type workerPacket struct {
|
type workerPacket struct {
|
||||||
StreamID uint32
|
Packet io.Packet
|
||||||
Data []byte
|
StreamID uint32
|
||||||
SrcMAC net.HardwareAddr
|
Data []byte
|
||||||
DstMAC net.HardwareAddr
|
SrcMAC net.HardwareAddr
|
||||||
SetVerdict func(io.Verdict, []byte) error
|
DstMAC net.HardwareAddr
|
||||||
|
Gen int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type workerResult struct {
|
||||||
|
Packet io.Packet
|
||||||
|
StreamID uint32
|
||||||
|
Verdict io.Verdict
|
||||||
|
ModifiedPacket []byte
|
||||||
|
Gen int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type worker struct {
|
type worker struct {
|
||||||
id int
|
id int
|
||||||
packetChan chan *workerPacket
|
packetChan chan *workerPacket
|
||||||
|
resultChan chan workerResult
|
||||||
logger Logger
|
logger Logger
|
||||||
macResolver *sourceMACResolver
|
macResolver *sourceMACResolver
|
||||||
|
|
||||||
tcpFlowMgr *tcpFlowManager
|
tcpFlowMgr *tcpFlowManager
|
||||||
udpSM *udpStreamManager
|
udpSM *udpStreamManager
|
||||||
|
|
||||||
modSerializeBuffer gopacket.SerializeBuffer
|
modSerializeBuffer gopacket.SerializeBuffer
|
||||||
}
|
}
|
||||||
@@ -43,6 +51,9 @@ type workerConfig struct {
|
|||||||
TCPMaxBufferedPagesTotal int // unused, kept for config compat
|
TCPMaxBufferedPagesTotal int // unused, kept for config compat
|
||||||
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
|
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
|
||||||
UDPMaxStreams int
|
UDPMaxStreams int
|
||||||
|
AnalyzerSelectionMode AnalyzerSelectionMode
|
||||||
|
ResultChan chan workerResult
|
||||||
|
Stats *statsCounters
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *workerConfig) fillDefaults() {
|
func (c *workerConfig) fillDefaults() {
|
||||||
@@ -61,7 +72,8 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
|
selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)
|
||||||
|
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
|
||||||
if config.Ruleset != nil {
|
if config.Ruleset != nil {
|
||||||
tcpMgr.updateRuleset(config.Ruleset, 0)
|
tcpMgr.updateRuleset(config.Ruleset, 0)
|
||||||
}
|
}
|
||||||
@@ -71,8 +83,10 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
Logger: config.Logger,
|
Logger: config.Logger,
|
||||||
Node: sfNode,
|
Node: sfNode,
|
||||||
Ruleset: config.Ruleset,
|
Ruleset: config.Ruleset,
|
||||||
|
Selector: selector,
|
||||||
|
Stats: config.Stats,
|
||||||
}
|
}
|
||||||
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
|
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -80,6 +94,7 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
return &worker{
|
return &worker{
|
||||||
id: config.ID,
|
id: config.ID,
|
||||||
packetChan: make(chan *workerPacket, config.ChanSize),
|
packetChan: make(chan *workerPacket, config.ChanSize),
|
||||||
|
resultChan: config.ResultChan,
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
macResolver: config.MACResolver,
|
macResolver: config.MACResolver,
|
||||||
tcpFlowMgr: tcpMgr,
|
tcpFlowMgr: tcpMgr,
|
||||||
@@ -97,6 +112,10 @@ func (w *worker) Feed(p *workerPacket) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *worker) FeedBlocking(p *workerPacket) {
|
||||||
|
w.packetChan <- p
|
||||||
|
}
|
||||||
|
|
||||||
func (w *worker) Run(ctx context.Context) {
|
func (w *worker) Run(ctx context.Context) {
|
||||||
w.logger.WorkerStart(w.id)
|
w.logger.WorkerStart(w.id)
|
||||||
defer w.logger.WorkerStop(w.id)
|
defer w.logger.WorkerStop(w.id)
|
||||||
@@ -109,7 +128,13 @@ func (w *worker) Run(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
v, b := w.handle(wp)
|
v, b := w.handle(wp)
|
||||||
_ = wp.SetVerdict(v, b)
|
w.resultChan <- workerResult{
|
||||||
|
Packet: wp.Packet,
|
||||||
|
StreamID: wp.StreamID,
|
||||||
|
Verdict: v,
|
||||||
|
ModifiedPacket: b,
|
||||||
|
Gen: wp.Gen,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -185,8 +210,7 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
|
|||||||
SrcMAC: srcMAC,
|
SrcMAC: srcMAC,
|
||||||
DstMAC: dstMAC,
|
DstMAC: dstMAC,
|
||||||
}
|
}
|
||||||
// Temporarily set payload on a UDP layer so existing UDP handling works
|
// Temporarily set payload on a UDP layer so existing UDP handling works.
|
||||||
// We pass the payload through the context
|
|
||||||
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
||||||
BaseLayer: layers.BaseLayer{Payload: payload},
|
BaseLayer: layers.BaseLayer{Payload: payload},
|
||||||
SrcPort: layers.UDPPort(udp.SrcPort),
|
SrcPort: layers.UDPPort(udp.SrcPort),
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
package io
|
package io
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package io
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNFQueueUnsupported = errors.New("nfqueue packet io is only supported on linux")
|
||||||
|
|
||||||
|
type NFQueuePacketIOConfig struct {
|
||||||
|
QueueSize uint32
|
||||||
|
ReadBuffer int
|
||||||
|
WriteBuffer int
|
||||||
|
Local bool
|
||||||
|
RST bool
|
||||||
|
NumQueues int
|
||||||
|
MaxPacketLen uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
||||||
|
_ = config
|
||||||
|
return nil, errNFQueueUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*unsupportedPacketIO) Register(context.Context, PacketCallback) error {
|
||||||
|
return errNFQueueUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*unsupportedPacketIO) SetVerdict(Packet, Verdict, []byte) error {
|
||||||
|
return errNFQueueUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*unsupportedPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) {
|
||||||
|
return nil, errNFQueueUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*unsupportedPacketIO) Close() error { return nil }
|
||||||
|
|
||||||
|
type unsupportedPacketIO struct{}
|
||||||
@@ -1,52 +1,45 @@
|
|||||||
package geo
|
package geo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
"net"
|
"net"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
geoSiteResultCacheSize = 1 << 16
|
||||||
|
geoSiteSetResultCacheSize = 1 << 16
|
||||||
|
)
|
||||||
|
|
||||||
type GeoMatcher struct {
|
type GeoMatcher struct {
|
||||||
geoLoader GeoLoader
|
geoLoader GeoLoader
|
||||||
geoSiteMatcher map[string]hostMatcher
|
geoSiteMatcher map[string]hostMatcher
|
||||||
siteMatcherLock sync.Mutex
|
siteMatcherLock sync.RWMutex
|
||||||
|
geoSiteSets map[string][]hostMatcher
|
||||||
|
siteSetLock sync.RWMutex
|
||||||
geoIpMatcher map[string]hostMatcher
|
geoIpMatcher map[string]hostMatcher
|
||||||
ipMatcherLock sync.Mutex
|
ipMatcherLock sync.RWMutex
|
||||||
|
geoSiteResult *boolLRUCache
|
||||||
|
geoSiteSetCache *boolLRUCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
|
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
|
||||||
return &GeoMatcher{
|
return &GeoMatcher{
|
||||||
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
|
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
|
||||||
geoSiteMatcher: make(map[string]hostMatcher),
|
geoSiteMatcher: make(map[string]hostMatcher),
|
||||||
geoIpMatcher: make(map[string]hostMatcher),
|
geoSiteSets: make(map[string][]hostMatcher),
|
||||||
|
geoIpMatcher: make(map[string]hostMatcher),
|
||||||
|
geoSiteResult: newBoolLRUCache(geoSiteResultCacheSize),
|
||||||
|
geoSiteSetCache: newBoolLRUCache(geoSiteSetResultCacheSize),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
||||||
g.ipMatcherLock.Lock()
|
matcher, ok := g.getOrCreateGeoIPMatcher(condition)
|
||||||
defer g.ipMatcherLock.Unlock()
|
if !ok || matcher == nil {
|
||||||
|
return false
|
||||||
matcher, ok := g.geoIpMatcher[condition]
|
|
||||||
if !ok {
|
|
||||||
// GeoIP matcher
|
|
||||||
condition = strings.ToLower(condition)
|
|
||||||
country := condition
|
|
||||||
if len(country) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
gMap, err := g.geoLoader.LoadGeoIP()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
list, ok := gMap[country]
|
|
||||||
if !ok || list == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
matcher, err = newGeoIPMatcher(list)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
g.geoIpMatcher[condition] = matcher
|
|
||||||
}
|
}
|
||||||
parseIp := net.ParseIP(ip)
|
parseIp := net.ParseIP(ip)
|
||||||
if parseIp == nil {
|
if parseIp == nil {
|
||||||
@@ -64,32 +57,69 @@ func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GeoMatcher) MatchGeoSite(site, condition string) bool {
|
func (g *GeoMatcher) MatchGeoSite(site, condition string) bool {
|
||||||
g.siteMatcherLock.Lock()
|
conditionKey := strings.TrimSpace(strings.ToLower(condition))
|
||||||
defer g.siteMatcherLock.Unlock()
|
if conditionKey == "" {
|
||||||
|
return false
|
||||||
matcher, ok := g.geoSiteMatcher[condition]
|
|
||||||
if !ok {
|
|
||||||
// MatchGeoSite matcher
|
|
||||||
condition = strings.ToLower(condition)
|
|
||||||
name, attrs := parseGeoSiteName(condition)
|
|
||||||
if len(name) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
gMap, err := g.geoLoader.LoadGeoSite()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
list, ok := gMap[name]
|
|
||||||
if !ok || list == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
matcher, err = newGeositeMatcher(list, attrs)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
g.geoSiteMatcher[condition] = matcher
|
|
||||||
}
|
}
|
||||||
return matcher.Match(HostInfo{Name: site})
|
cacheKey := site + "\x1f" + conditionKey
|
||||||
|
if v, ok := g.geoSiteResult.Get(cacheKey); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
matcher, ok := g.getOrCreateGeoSiteMatcher(condition)
|
||||||
|
if !ok || matcher == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
result := matcher.Match(HostInfo{Name: site})
|
||||||
|
g.geoSiteResult.Set(cacheKey, result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GeoMatcher) MatchGeoSiteSet(site string, set *SiteConditionSet) bool {
|
||||||
|
if set == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conditions := normalizeGeoSiteSetConditions(set.Conditions)
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := strings.Join(conditions, "\x1f")
|
||||||
|
cacheKey := site + "\x1e" + key
|
||||||
|
if v, ok := g.geoSiteSetCache.Get(cacheKey); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
g.siteSetLock.RLock()
|
||||||
|
matchers, ok := g.geoSiteSets[key]
|
||||||
|
g.siteSetLock.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
compiled := make([]hostMatcher, 0, len(conditions))
|
||||||
|
for _, condition := range conditions {
|
||||||
|
m, ok := g.getOrCreateGeoSiteMatcher(condition)
|
||||||
|
if ok && m != nil {
|
||||||
|
compiled = append(compiled, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.siteSetLock.Lock()
|
||||||
|
if existing, exists := g.geoSiteSets[key]; exists {
|
||||||
|
matchers = existing
|
||||||
|
} else {
|
||||||
|
g.geoSiteSets[key] = compiled
|
||||||
|
matchers = compiled
|
||||||
|
}
|
||||||
|
g.siteSetLock.Unlock()
|
||||||
|
}
|
||||||
|
if len(matchers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
host := HostInfo{Name: site}
|
||||||
|
for _, matcher := range matchers {
|
||||||
|
if matcher.Match(host) {
|
||||||
|
g.geoSiteSetCache.Set(cacheKey, true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.geoSiteSetCache.Set(cacheKey, false)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GeoMatcher) LoadGeoSite() error {
|
func (g *GeoMatcher) LoadGeoSite() error {
|
||||||
@@ -111,3 +141,152 @@ func parseGeoSiteName(s string) (string, []string) {
|
|||||||
}
|
}
|
||||||
return base, attrs
|
return base, attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GeoMatcher) getOrCreateGeoSiteMatcher(condition string) (hostMatcher, bool) {
|
||||||
|
condition = strings.TrimSpace(strings.ToLower(condition))
|
||||||
|
if condition == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
g.siteMatcherLock.RLock()
|
||||||
|
matcher, ok := g.geoSiteMatcher[condition]
|
||||||
|
g.siteMatcherLock.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return matcher, true
|
||||||
|
}
|
||||||
|
|
||||||
|
name, attrs := parseGeoSiteName(condition)
|
||||||
|
if len(name) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
gMap, err := g.geoLoader.LoadGeoSite()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
list, ok := gMap[name]
|
||||||
|
if !ok || list == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
matcher, err = newGeositeMatcher(list, attrs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
g.siteMatcherLock.Lock()
|
||||||
|
if existing, exists := g.geoSiteMatcher[condition]; exists {
|
||||||
|
matcher = existing
|
||||||
|
} else {
|
||||||
|
g.geoSiteMatcher[condition] = matcher
|
||||||
|
}
|
||||||
|
g.siteMatcherLock.Unlock()
|
||||||
|
return matcher, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GeoMatcher) getOrCreateGeoIPMatcher(condition string) (hostMatcher, bool) {
|
||||||
|
condition = strings.TrimSpace(strings.ToLower(condition))
|
||||||
|
if condition == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
g.ipMatcherLock.RLock()
|
||||||
|
matcher, ok := g.geoIpMatcher[condition]
|
||||||
|
g.ipMatcherLock.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return matcher, true
|
||||||
|
}
|
||||||
|
|
||||||
|
gMap, err := g.geoLoader.LoadGeoIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
list, ok := gMap[condition]
|
||||||
|
if !ok || list == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
matcher, err = newGeoIPMatcher(list)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
g.ipMatcherLock.Lock()
|
||||||
|
if existing, exists := g.geoIpMatcher[condition]; exists {
|
||||||
|
matcher = existing
|
||||||
|
} else {
|
||||||
|
g.geoIpMatcher[condition] = matcher
|
||||||
|
}
|
||||||
|
g.ipMatcherLock.Unlock()
|
||||||
|
return matcher, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGeoSiteSetConditions(in []string) []string {
|
||||||
|
if len(in) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(in))
|
||||||
|
seen := make(map[string]struct{}, len(in))
|
||||||
|
for _, v := range in {
|
||||||
|
s := strings.TrimSpace(strings.ToLower(v))
|
||||||
|
if s == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[s]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[s] = struct{}{}
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
sort.Strings(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
type boolLRUCache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cap int
|
||||||
|
ll *list.List
|
||||||
|
items map[string]*list.Element
|
||||||
|
}
|
||||||
|
|
||||||
|
type boolCacheEntry struct {
|
||||||
|
key string
|
||||||
|
value bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBoolLRUCache(capacity int) *boolLRUCache {
|
||||||
|
if capacity <= 0 {
|
||||||
|
capacity = 1
|
||||||
|
}
|
||||||
|
return &boolLRUCache{
|
||||||
|
cap: capacity,
|
||||||
|
ll: list.New(),
|
||||||
|
items: make(map[string]*list.Element, capacity),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *boolLRUCache) Get(key string) (bool, bool) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if ele, ok := c.items[key]; ok {
|
||||||
|
c.ll.MoveToFront(ele)
|
||||||
|
entry := ele.Value.(boolCacheEntry)
|
||||||
|
return entry.value, true
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *boolLRUCache) Set(key string, value bool) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if ele, ok := c.items[key]; ok {
|
||||||
|
ele.Value = boolCacheEntry{key: key, value: value}
|
||||||
|
c.ll.MoveToFront(ele)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ele := c.ll.PushFront(boolCacheEntry{key: key, value: value})
|
||||||
|
c.items[key] = ele
|
||||||
|
if c.ll.Len() <= c.cap {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
back := c.ll.Back()
|
||||||
|
if back == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry := back.Value.(boolCacheEntry)
|
||||||
|
delete(c.items, entry.key)
|
||||||
|
c.ll.Remove(back)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package geo
|
package geo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fakeGeoLoader struct {
|
type fakeGeoLoader struct {
|
||||||
geoip map[string]*v2geo.GeoIP
|
geoip map[string]*v2geo.GeoIP
|
||||||
geosite map[string]*v2geo.GeoSite
|
geosite map[string]*v2geo.GeoSite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,6 +111,83 @@ func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSiteSet(t *testing.T) {
|
||||||
|
loader := &fakeGeoLoader{
|
||||||
|
geosite: map[string]*v2geo.GeoSite{
|
||||||
|
"openai": {
|
||||||
|
Domain: []*v2geo.Domain{
|
||||||
|
{Type: v2geo.Domain_Plain, Value: "openai"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"google": {
|
||||||
|
Domain: []*v2geo.Domain{
|
||||||
|
{Type: v2geo.Domain_RootDomain, Value: "google.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
g.geoLoader = loader
|
||||||
|
|
||||||
|
set := &SiteConditionSet{Conditions: []string{" google ", "openai", "OPENAI"}}
|
||||||
|
if !g.MatchGeoSiteSet("api.openai.com", set) {
|
||||||
|
t.Error("MatchGeoSiteSet should match openai")
|
||||||
|
}
|
||||||
|
if !g.MatchGeoSiteSet("mail.google.com", set) {
|
||||||
|
t.Error("MatchGeoSiteSet should match google")
|
||||||
|
}
|
||||||
|
if g.MatchGeoSiteSet("example.com", set) {
|
||||||
|
t.Error("MatchGeoSiteSet should not match unrelated host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingMatcher struct {
|
||||||
|
calls *atomic.Uint64
|
||||||
|
match bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m countingMatcher) Match(host HostInfo) bool {
|
||||||
|
_ = host
|
||||||
|
m.calls.Add(1)
|
||||||
|
return m.match
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSite_UsesResultCache(t *testing.T) {
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
var calls atomic.Uint64
|
||||||
|
g.geoSiteMatcher["openai"] = countingMatcher{calls: &calls, match: true}
|
||||||
|
|
||||||
|
if !g.MatchGeoSite("api.openai.com", "openai") {
|
||||||
|
t.Fatal("expected match")
|
||||||
|
}
|
||||||
|
if !g.MatchGeoSite("api.openai.com", "openai") {
|
||||||
|
t.Fatal("expected cached match")
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("matcher calls=%d want=1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSiteSet_UsesResultCache(t *testing.T) {
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
var calls atomic.Uint64
|
||||||
|
g.geoSiteSets["openai\x1fyoutube"] = []hostMatcher{
|
||||||
|
countingMatcher{calls: &calls, match: false},
|
||||||
|
countingMatcher{calls: &calls, match: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
set := &SiteConditionSet{Conditions: []string{"youtube", "openai"}}
|
||||||
|
if !g.MatchGeoSiteSet("www.youtube.com", set) {
|
||||||
|
t.Fatal("expected match")
|
||||||
|
}
|
||||||
|
if !g.MatchGeoSiteSet("www.youtube.com", set) {
|
||||||
|
t.Fatal("expected cached match")
|
||||||
|
}
|
||||||
|
if got := calls.Load(); got != 2 {
|
||||||
|
t.Fatalf("matcher calls=%d want=2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ipv4(a, b, c, d byte) []byte {
|
func ipv4(a, b, c, d byte) []byte {
|
||||||
return []byte{a, b, c, d}
|
return []byte{a, b, c, d}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ type HostInfo struct {
|
|||||||
IPv6 net.IP
|
IPv6 net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SiteConditionSet struct {
|
||||||
|
Conditions []string
|
||||||
|
}
|
||||||
|
|
||||||
func (h HostInfo) String() string {
|
func (h HostInfo) String() string {
|
||||||
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
|
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
|
||||||
}
|
}
|
||||||
|
|||||||
+203
-24
@@ -60,8 +60,8 @@ type compiledExprRule struct {
|
|||||||
ModInstance modifier.Instance
|
ModInstance modifier.Instance
|
||||||
Program *vm.Program
|
Program *vm.Program
|
||||||
GeoSiteConditions []string
|
GeoSiteConditions []string
|
||||||
StartTimeSecs int // seconds since midnight, -1 if unset
|
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||||
StopTimeSecs int // seconds since midnight, -1 if unset
|
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||||
Weekdays []time.Weekday
|
Weekdays []time.Weekday
|
||||||
WeekdaysNegated bool
|
WeekdaysNegated bool
|
||||||
}
|
}
|
||||||
@@ -86,6 +86,7 @@ type exprRuleset struct {
|
|||||||
Ans []analyzer.Analyzer
|
Ans []analyzer.Analyzer
|
||||||
Logger Logger
|
Logger Logger
|
||||||
GeoMatcher *geo.GeoMatcher
|
GeoMatcher *geo.GeoMatcher
|
||||||
|
stats *statsCounters
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
||||||
@@ -93,9 +94,24 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||||
|
start := time.Now()
|
||||||
|
if r.stats != nil {
|
||||||
|
r.stats.MatchCalls.Add(1)
|
||||||
|
defer func() {
|
||||||
|
r.stats.MatchLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
env := envPool.Get().(map[string]any)
|
env := envPool.Get().(map[string]any)
|
||||||
clear(env)
|
clear(env)
|
||||||
populateExprEnv(env, info)
|
macMap, ipMap, portMap := populateExprEnv(env, info)
|
||||||
|
releaseEnv := func() {
|
||||||
|
clear(env)
|
||||||
|
envPool.Put(env)
|
||||||
|
putSubMap(macMap)
|
||||||
|
putSubMap(ipMap)
|
||||||
|
putSubMap(portMap)
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for _, rule := range r.Rules {
|
for _, rule := range r.Rules {
|
||||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||||
@@ -103,6 +119,9 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
}
|
}
|
||||||
v, err := vm.Run(rule.Program, env)
|
v, err := vm.Run(rule.Program, env)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if r.stats != nil {
|
||||||
|
r.stats.MatchErrors.Add(1)
|
||||||
|
}
|
||||||
r.Logger.MatchError(info, rule.Name, err)
|
r.Logger.MatchError(info, rule.Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -115,7 +134,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
r.Logger.Log(logInfo, rule.Name)
|
r.Logger.Log(logInfo, rule.Name)
|
||||||
}
|
}
|
||||||
if rule.Action != nil {
|
if rule.Action != nil {
|
||||||
envPool.Put(env)
|
releaseEnv()
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: *rule.Action,
|
Action: *rule.Action,
|
||||||
ModInstance: rule.ModInstance,
|
ModInstance: rule.ModInstance,
|
||||||
@@ -123,12 +142,26 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
envPool.Put(env)
|
releaseEnv()
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: ActionMaybe,
|
Action: ActionMaybe,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *exprRuleset) Stats() Stats {
|
||||||
|
if r == nil || r.stats == nil {
|
||||||
|
return Stats{}
|
||||||
|
}
|
||||||
|
return Stats{
|
||||||
|
MatchCalls: r.stats.MatchCalls.Load(),
|
||||||
|
MatchErrors: r.stats.MatchErrors.Load(),
|
||||||
|
MatchLatencyNanos: r.stats.MatchLatencyNanos.Load(),
|
||||||
|
LookupCalls: r.stats.LookupCalls.Load(),
|
||||||
|
LookupErrors: r.stats.LookupErrors.Load(),
|
||||||
|
LookupLatencyNanos: r.stats.LookupLatencyNanos.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CompileExprRules compiles a list of expression rules into a ruleset.
|
// CompileExprRules compiles a list of expression rules into a ruleset.
|
||||||
// It returns an error if any of the rules are invalid, or if any of the analyzers
|
// It returns an error if any of the rules are invalid, or if any of the analyzers
|
||||||
// used by the rules are unknown (not provided in the analyzer list).
|
// used by the rules are unknown (not provided in the analyzer list).
|
||||||
@@ -137,7 +170,8 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
fullAnMap := analyzersToMap(ans)
|
fullAnMap := analyzersToMap(ans)
|
||||||
fullModMap := modifiersToMap(mods)
|
fullModMap := modifiersToMap(mods)
|
||||||
depAnMap := make(map[string]analyzer.Analyzer)
|
depAnMap := make(map[string]analyzer.Analyzer)
|
||||||
funcMap, geoMatcher := buildFunctionMap(config)
|
stats := &statsCounters{}
|
||||||
|
funcMap, geoMatcher := buildFunctionMap(config, stats)
|
||||||
// Compile all rules and build a map of analyzers that are used by the rules.
|
// Compile all rules and build a map of analyzers that are used by the rules.
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.Action == "" && !rule.Log {
|
if rule.Action == "" && !rule.Log {
|
||||||
@@ -152,7 +186,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
action = &a
|
action = &a
|
||||||
}
|
}
|
||||||
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
|
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
|
||||||
patcher := &idPatcher{FuncMap: funcMap}
|
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: geoMatcher}
|
||||||
program, err := expr.Compile(rule.Expr,
|
program, err := expr.Compile(rule.Expr,
|
||||||
func(c *conf.Config) {
|
func(c *conf.Config) {
|
||||||
c.Strict = false
|
c.Strict = false
|
||||||
@@ -242,29 +276,47 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
Ans: depAns,
|
Ans: depAns,
|
||||||
Logger: config.Logger,
|
Logger: config.Logger,
|
||||||
GeoMatcher: geoMatcher,
|
GeoMatcher: geoMatcher,
|
||||||
|
stats: stats,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func populateExprEnv(m map[string]any, info StreamInfo) {
|
func populateExprEnv(m map[string]any, info StreamInfo) (macMap, ipMap, portMap map[string]any) {
|
||||||
|
macMap = getSubMap()
|
||||||
|
ipMap = getSubMap()
|
||||||
|
portMap = getSubMap()
|
||||||
|
|
||||||
|
macMap["src"] = info.SrcMAC.String()
|
||||||
|
macMap["dst"] = info.DstMAC.String()
|
||||||
|
ipMap["src"] = info.SrcIP.String()
|
||||||
|
ipMap["dst"] = info.DstIP.String()
|
||||||
|
portMap["src"] = info.SrcPort
|
||||||
|
portMap["dst"] = info.DstPort
|
||||||
|
|
||||||
m["id"] = info.ID
|
m["id"] = info.ID
|
||||||
m["proto"] = info.Protocol.String()
|
m["proto"] = info.Protocol.String()
|
||||||
m["mac"] = map[string]string{
|
m["mac"] = macMap
|
||||||
"src": info.SrcMAC.String(),
|
m["ip"] = ipMap
|
||||||
"dst": info.DstMAC.String(),
|
m["port"] = portMap
|
||||||
}
|
|
||||||
m["ip"] = map[string]string{
|
|
||||||
"src": info.SrcIP.String(),
|
|
||||||
"dst": info.DstIP.String(),
|
|
||||||
}
|
|
||||||
m["port"] = map[string]uint16{
|
|
||||||
"src": info.SrcPort,
|
|
||||||
"dst": info.DstPort,
|
|
||||||
}
|
|
||||||
for anName, anProps := range info.Props {
|
for anName, anProps := range info.Props {
|
||||||
if len(anProps) != 0 {
|
if len(anProps) != 0 {
|
||||||
m[anName] = anProps
|
m[anName] = anProps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return macMap, ipMap, portMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSubMap() map[string]any {
|
||||||
|
m := subMapPool.Get().(map[string]any)
|
||||||
|
clear(m)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func putSubMap(m map[string]any) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clear(m)
|
||||||
|
subMapPool.Put(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isBuiltInAnalyzer(name string) bool {
|
func isBuiltInAnalyzer(name string) bool {
|
||||||
@@ -329,11 +381,15 @@ func (v *idVisitor) Visit(node *ast.Node) {
|
|||||||
// idPatcher patches the AST during expr compilation, replacing certain values with
|
// idPatcher patches the AST during expr compilation, replacing certain values with
|
||||||
// their internal representations for better runtime performance.
|
// their internal representations for better runtime performance.
|
||||||
type idPatcher struct {
|
type idPatcher struct {
|
||||||
FuncMap map[string]*Function
|
FuncMap map[string]*Function
|
||||||
Err error
|
GeoMatcher *geo.GeoMatcher
|
||||||
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *idPatcher) Visit(node *ast.Node) {
|
func (p *idPatcher) Visit(node *ast.Node) {
|
||||||
|
if p.tryPatchGeoSiteORChain(node) {
|
||||||
|
return
|
||||||
|
}
|
||||||
switch (*node).(type) {
|
switch (*node).(type) {
|
||||||
case *ast.CallNode:
|
case *ast.CallNode:
|
||||||
callNode := (*node).(*ast.CallNode)
|
callNode := (*node).(*ast.CallNode)
|
||||||
@@ -352,6 +408,108 @@ func (p *idPatcher) Visit(node *ast.Node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *idPatcher) tryPatchGeoSiteORChain(node *ast.Node) bool {
|
||||||
|
if p == nil || p.GeoMatcher == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
terms, ok := collectGeoSiteORChain(*node)
|
||||||
|
if !ok || len(terms) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
hostExpr := strings.TrimSpace(terms[0].hostExpr)
|
||||||
|
if hostExpr == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conditions := make([]string, 0, len(terms))
|
||||||
|
for _, term := range terms {
|
||||||
|
if strings.TrimSpace(term.hostExpr) != hostExpr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conditions = append(conditions, term.condition)
|
||||||
|
}
|
||||||
|
normalized := normalizeUniqueLowerStrings(conditions)
|
||||||
|
if len(normalized) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
hostNode, err := parser.Parse(hostExpr)
|
||||||
|
if err != nil || hostNode == nil || hostNode.Node == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
call := &ast.CallNode{
|
||||||
|
Callee: &ast.IdentifierNode{Value: "geosite_set"},
|
||||||
|
Arguments: []ast.Node{
|
||||||
|
hostNode.Node,
|
||||||
|
&ast.ConstantNode{Value: &geo.SiteConditionSet{Conditions: normalized}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ast.Patch(node, call)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type geositeTerm struct {
|
||||||
|
hostExpr string
|
||||||
|
condition string
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectGeoSiteORChain(node ast.Node) ([]geositeTerm, bool) {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *ast.BinaryNode:
|
||||||
|
if n.Operator != "or" && n.Operator != "||" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
left, ok := collectGeoSiteORChain(n.Left)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
right, ok := collectGeoSiteORChain(n.Right)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
out := make([]geositeTerm, 0, len(left)+len(right))
|
||||||
|
out = append(out, left...)
|
||||||
|
out = append(out, right...)
|
||||||
|
return out, true
|
||||||
|
case *ast.CallNode:
|
||||||
|
idNode, ok := n.Callee.(*ast.IdentifierNode)
|
||||||
|
if !ok || len(n.Arguments) < 2 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
name := strings.ToLower(idNode.Value)
|
||||||
|
if name == "geosite" {
|
||||||
|
condNode, ok := n.Arguments[1].(*ast.StringNode)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return []geositeTerm{{
|
||||||
|
hostExpr: n.Arguments[0].String(),
|
||||||
|
condition: condNode.Value,
|
||||||
|
}}, true
|
||||||
|
}
|
||||||
|
if name != "geosite_set" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
setNode, ok := n.Arguments[1].(*ast.ConstantNode)
|
||||||
|
if !ok || setNode.Value == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
set, ok := setNode.Value.(*geo.SiteConditionSet)
|
||||||
|
if !ok || set == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if len(set.Conditions) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
out := make([]geositeTerm, 0, len(set.Conditions))
|
||||||
|
hostExpr := n.Arguments[0].String()
|
||||||
|
for _, condition := range set.Conditions {
|
||||||
|
out = append(out, geositeTerm{hostExpr: hostExpr, condition: condition})
|
||||||
|
}
|
||||||
|
return out, true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
InitFunc func() error
|
InitFunc func() error
|
||||||
PatchFunc func(args *[]ast.Node) error
|
PatchFunc func(args *[]ast.Node) error
|
||||||
@@ -359,7 +517,7 @@ type Function struct {
|
|||||||
Types []reflect.Type
|
Types []reflect.Type
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatcher) {
|
func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*Function, *geo.GeoMatcher) {
|
||||||
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
||||||
return map[string]*Function{
|
return map[string]*Function{
|
||||||
"geoip": {
|
"geoip": {
|
||||||
@@ -378,6 +536,16 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
|
|||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
||||||
},
|
},
|
||||||
|
"geosite_set": {
|
||||||
|
InitFunc: geoMatcher.LoadGeoSite,
|
||||||
|
PatchFunc: nil,
|
||||||
|
Func: func(params ...any) (any, error) {
|
||||||
|
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil
|
||||||
|
},
|
||||||
|
Types: []reflect.Type{
|
||||||
|
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
|
||||||
|
},
|
||||||
|
},
|
||||||
"cidr": {
|
"cidr": {
|
||||||
InitFunc: nil,
|
InitFunc: nil,
|
||||||
PatchFunc: func(args *[]ast.Node) error {
|
PatchFunc: func(args *[]ast.Node) error {
|
||||||
@@ -425,9 +593,20 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
|
start := time.Now()
|
||||||
|
if stats != nil {
|
||||||
|
stats.LookupCalls.Add(1)
|
||||||
|
defer func() {
|
||||||
|
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
|
||||||
|
}()
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
||||||
|
if err != nil && stats != nil {
|
||||||
|
stats.LookupErrors.Add(1)
|
||||||
|
}
|
||||||
|
return out, err
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{
|
Types: []reflect.Type{
|
||||||
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
|
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
|
||||||
|
|||||||
@@ -2,9 +2,14 @@ package ruleset
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
|
||||||
|
|
||||||
|
"github.com/expr-lang/expr/ast"
|
||||||
|
"github.com/expr-lang/expr/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractGeoSiteConditions(t *testing.T) {
|
func TestExtractGeoSiteConditions(t *testing.T) {
|
||||||
@@ -63,3 +68,23 @@ func TestMatchGeoSiteConditions(t *testing.T) {
|
|||||||
t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want)
|
t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
|
||||||
|
tree, err := parser.Parse(`geosite(tls.req.sni, "google") || geosite(tls.req.sni, "youtube") || geosite(tls.req.sni, "openai")`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse expression: %v", err)
|
||||||
|
}
|
||||||
|
root := tree.Node
|
||||||
|
patcher := &idPatcher{GeoMatcher: geo.NewGeoMatcher("", "")}
|
||||||
|
ast.Walk(&root, patcher)
|
||||||
|
if patcher.Err != nil {
|
||||||
|
t.Fatalf("patch error: %v", patcher.Err)
|
||||||
|
}
|
||||||
|
got := root.String()
|
||||||
|
if !strings.Contains(got, "geosite_set(") {
|
||||||
|
t.Fatalf("expected geosite_set rewrite, got %q", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(got, "||") || strings.Contains(got, " or ") {
|
||||||
|
t.Fatalf("expected OR chain to be collapsed, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
"git.difuse.io/Difuse/Mellaris/modifier"
|
"git.difuse.io/Difuse/Mellaris/modifier"
|
||||||
@@ -95,6 +96,28 @@ type Ruleset interface {
|
|||||||
Match(StreamInfo) MatchResult
|
Match(StreamInfo) MatchResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
MatchCalls uint64
|
||||||
|
MatchErrors uint64
|
||||||
|
MatchLatencyNanos uint64
|
||||||
|
LookupCalls uint64
|
||||||
|
LookupErrors uint64
|
||||||
|
LookupLatencyNanos uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type statsCounters struct {
|
||||||
|
MatchCalls atomic.Uint64
|
||||||
|
MatchErrors atomic.Uint64
|
||||||
|
MatchLatencyNanos atomic.Uint64
|
||||||
|
LookupCalls atomic.Uint64
|
||||||
|
LookupErrors atomic.Uint64
|
||||||
|
LookupLatencyNanos atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type StatsProvider interface {
|
||||||
|
Stats() Stats
|
||||||
|
}
|
||||||
|
|
||||||
// Logger is the logging interface for the ruleset.
|
// Logger is the logging interface for the ruleset.
|
||||||
type Logger interface {
|
type Logger interface {
|
||||||
Log(info StreamInfo, name string)
|
Log(info StreamInfo, name string)
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package mellaris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.difuse.io/Difuse/Mellaris/engine"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
Engine engine.Stats
|
||||||
|
Ruleset ruleset.Stats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *App) Stats() Stats {
|
||||||
|
if a == nil || a.engine == nil {
|
||||||
|
return Stats{}
|
||||||
|
}
|
||||||
|
out := Stats{Engine: a.engine.Stats()}
|
||||||
|
if rs, ok := a.ruleset.(ruleset.StatsProvider); ok {
|
||||||
|
out.Ruleset = rs.Stats()
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user