Improves flow handling and adds runtime stats APIs

Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
This commit is contained in:
2026-05-13 06:10:38 +05:30
parent 3f895adb43
commit 7a3f6e945d
23 changed files with 1440 additions and 152 deletions
+235
View File
@@ -0,0 +1,235 @@
package engine
import (
"bytes"
"strings"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// analyzerSelector prunes the set of candidate analyzers for a new flow using
// cheap first-payload signature checks, depending on the configured mode.
// A nil selector is valid and performs no pruning.
type analyzerSelector struct {
	mode  AnalyzerSelectionMode // selection strategy; see AnalyzerSelectionMode constants
	stats *statsCounters        // counters for selection metrics; may be nil
}
// newAnalyzerSelector builds a selector for the given mode. An empty mode
// defaults to signature-based selection.
func newAnalyzerSelector(mode AnalyzerSelectionMode, stats *statsCounters) *analyzerSelector {
	effective := mode
	if effective == "" {
		effective = AnalyzerSelectionModeSignature
	}
	return &analyzerSelector{mode: effective, stats: stats}
}
// SelectTCP returns the subset of analyzers worth running against a TCP flow
// whose first payload is payload. In "always" mode, for a nil selector, or when
// there is at most one candidate, the input is returned unchanged.
func (s *analyzerSelector) SelectTCP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
	if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
		return ans
	}
	return s.selectKnown(ans, knownTCPAnalyzers, tcpAllowedAnalyzers(payload))
}

// SelectUDP is the UDP counterpart of SelectTCP, using the first datagram's
// payload for signature matching.
func (s *analyzerSelector) SelectUDP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
	if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
		return ans
	}
	return s.selectKnown(ans, knownUDPAnalyzers, udpAllowedAnalyzers(payload))
}

// selectKnown filters ans against the signature-derived allow set. Analyzers
// not present in known are always kept (conservative pass-through for custom
// analyzers). If no signature matched (allowed is empty) or filtering would
// remove every analyzer, the full input is returned so no traffic is left
// unanalyzed.
func (s *analyzerSelector) selectKnown(ans []analyzer.Analyzer, known map[string]struct{}, allowed map[string]bool) []analyzer.Analyzer {
	if len(allowed) == 0 {
		return ans
	}
	out := make([]analyzer.Analyzer, 0, len(ans))
	for _, a := range ans {
		name := strings.ToLower(a.Name())
		if _, isKnown := known[name]; !isKnown {
			// Unknown analyzer: keep it, we cannot reason about its protocol.
			out = append(out, a)
			continue
		}
		if allowed[name] {
			out = append(out, a)
		}
	}
	s.recordSelection(len(ans), len(out))
	if len(out) == 0 {
		return ans
	}
	return out
}
// recordSelection updates selection counters: every selection pass bumps the
// total, and passes that removed at least one analyzer bump the pruned count.
// No-op for a nil selector, nil stats, or an empty candidate set.
func (s *analyzerSelector) recordSelection(total, selected int) {
	if s == nil || s.stats == nil {
		return
	}
	if total <= 0 {
		return
	}
	s.stats.AnalyzerSelectionsTotal.Add(1)
	if pruned := selected < total; pruned {
		s.stats.AnalyzerSelectionsPruned.Add(1)
	}
}
var (
	// knownTCPAnalyzers lists the lower-cased names of built-in TCP analyzers
	// the signature selector understands. Analyzers outside this set are never
	// pruned by selectKnown-style filtering (conservative pass-through).
	knownTCPAnalyzers = map[string]struct{}{
		"fet":     {},
		"http":    {},
		"socks":   {},
		"ssh":     {},
		"tls":     {},
		"trojan":  {},
		"dns":     {},
		"openvpn": {},
	}
	// knownUDPAnalyzers is the UDP counterpart of knownTCPAnalyzers.
	knownUDPAnalyzers = map[string]struct{}{
		"dns":       {},
		"openvpn":   {},
		"quic":      {},
		"wireguard": {},
	}
)
// tcpAllowedAnalyzers maps a TCP flow's first payload to the set of analyzer
// names whose protocol signatures it matches. Returns nil when nothing
// matched, which callers treat as "no pruning".
func tcpAllowedAnalyzers(payload []byte) map[string]bool {
	signatures := []struct {
		match func([]byte) bool
		names []string
	}{
		// "fet" (flow entropy test) piggybacks on every recognized protocol.
		{looksLikeTLS, []string{"tls", "trojan", "fet"}},
		{looksLikeHTTP, []string{"http", "fet"}},
		{looksLikeSSH, []string{"ssh", "fet"}},
		{looksLikeSOCKS, []string{"socks", "fet"}},
		{looksLikeDNSTCP, []string{"dns", "fet"}},
	}
	var allowed map[string]bool
	for _, sig := range signatures {
		if !sig.match(payload) {
			continue
		}
		if allowed == nil {
			allowed = make(map[string]bool, 4)
		}
		for _, name := range sig.names {
			allowed[name] = true
		}
	}
	return allowed
}
// udpAllowedAnalyzers maps a UDP datagram's payload to the analyzer names
// whose protocol signatures it matches. Returns nil when nothing matched,
// which callers treat as "no pruning".
func udpAllowedAnalyzers(payload []byte) map[string]bool {
	var allowed map[string]bool
	mark := func(name string) {
		// Allocate lazily so the no-match case stays allocation-free.
		if allowed == nil {
			allowed = make(map[string]bool, 4)
		}
		allowed[name] = true
	}
	if looksLikeWireGuard(payload) {
		mark("wireguard")
	}
	if looksLikeOpenVPN(payload) {
		mark("openvpn")
	}
	if looksLikeQUIC(payload) {
		mark("quic")
	}
	if looksLikeDNSUDP(payload) {
		mark("dns")
	}
	return allowed
}
// looksLikeTLS reports whether payload starts like a TLS record: content type
// handshake (0x16) or application data (0x17), legacy major version 3, and a
// plausible minor version byte (<= 9).
func looksLikeTLS(payload []byte) bool {
	if len(payload) < 3 {
		return false
	}
	contentType, major, minor := payload[0], payload[1], payload[2]
	if contentType != 0x16 && contentType != 0x17 {
		return false
	}
	return major == 0x03 && minor <= 0x09
}
// looksLikeHTTP reports whether payload begins with the first three letters of
// a common HTTP request method (case-insensitive).
func looksLikeHTTP(payload []byte) bool {
	if len(payload) < 3 {
		return false
	}
	// Three bytes disambiguate GET/HEAD/POST/PUT/DELETE/CONNECT/OPTIONS/TRACE/PATCH.
	switch strings.ToUpper(string(payload[:3])) {
	case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
		return true
	}
	return false
}
// looksLikeSSH reports whether payload begins with the SSH identification
// banner prefix "SSH-".
func looksLikeSSH(payload []byte) bool {
	const banner = "SSH-"
	if len(payload) < len(banner) {
		return false
	}
	return bytes.HasPrefix(payload, []byte(banner))
}
// looksLikeSOCKS reports whether payload could be a SOCKS4 or SOCKS5 request:
// at least two bytes, starting with the version byte 0x04 or 0x05.
func looksLikeSOCKS(payload []byte) bool {
	if len(payload) < 2 {
		return false
	}
	switch payload[0] {
	case 0x04, 0x05:
		return true
	default:
		return false
	}
}
// looksLikeDNSTCP reports whether payload resembles DNS over TCP: a 2-byte
// big-endian length prefix that fits inside payload, followed by a DNS header
// with at least one question or answer record.
func looksLikeDNSTCP(payload []byte) bool {
	const minLen = 14 // 2-byte length prefix + 12-byte DNS header
	if len(payload) < minLen {
		return false
	}
	be16 := func(hi, lo byte) int { return int(hi)<<8 | int(lo) }
	msgLen := be16(payload[0], payload[1])
	if msgLen <= 0 || msgLen+2 > len(payload) {
		return false
	}
	questions := be16(payload[6], payload[7])
	answers := be16(payload[8], payload[9])
	return questions+answers > 0
}
// looksLikeDNSUDP reports whether payload resembles a DNS message over UDP: a
// full 12-byte header whose four record-count fields are not all zero.
func looksLikeDNSUDP(payload []byte) bool {
	if len(payload) < 12 {
		return false
	}
	be16 := func(off int) int { return int(payload[off])<<8 | int(payload[off+1]) }
	// QDCOUNT + ANCOUNT + NSCOUNT + ARCOUNT.
	return be16(4)+be16(6)+be16(8)+be16(10) > 0
}
// looksLikeQUIC reports whether payload resembles a QUIC long-header packet:
// the high bit of the first byte set and a non-zero 4-byte version field
// (version 0 is reserved for version negotiation).
func looksLikeQUIC(payload []byte) bool {
	if len(payload) < 6 {
		return false
	}
	if payload[0]&0x80 == 0 {
		// Short header: no version field to inspect.
		return false
	}
	var version uint32
	for _, b := range payload[1:5] {
		version = version<<8 | uint32(b)
	}
	return version != 0
}
// looksLikeOpenVPN reports whether the first byte carries a plausible OpenVPN
// opcode (top 5 bits in the range 1..11).
func looksLikeOpenVPN(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}
	opcode := payload[0] >> 3
	if opcode < 1 || opcode > 11 {
		return false
	}
	return true
}
// looksLikeWireGuard reports whether payload resembles a WireGuard message:
// a message-type byte in 1..4 followed by three zero reserved bytes.
func looksLikeWireGuard(payload []byte) bool {
	if len(payload) < 4 {
		return false
	}
	msgType := payload[0]
	if msgType == 0 || msgType > 4 {
		return false
	}
	// Reserved bytes must all be zero.
	return payload[1]|payload[2]|payload[3] == 0
}
+56
View File
@@ -0,0 +1,56 @@
package engine
import (
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// namedAnalyzer is a minimal analyzer stub whose only meaningful property is
// its name; used to exercise selector filtering by name.
type namedAnalyzer struct{ name string }
// Name returns the stub's fixed analyzer name.
func (a namedAnalyzer) Name() string { return a.name }
// Limit returns 0 (no byte limit).
func (a namedAnalyzer) Limit() int { return 0 }
// TestSignatureSelectorTCPPrunesByPayloadNotPort verifies that signature-based
// selection keeps TLS-compatible analyzers and drops incompatible ones when
// the payload starts with a TLS record — regardless of any port information.
func TestSignatureSelectorTCPPrunesByPayloadNotPort(t *testing.T) {
	sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
	candidates := []analyzer.Analyzer{
		namedAnalyzer{"http"},
		namedAnalyzer{"tls"},
		namedAnalyzer{"trojan"},
		namedAnalyzer{"ssh"},
		namedAnalyzer{"socks"},
		namedAnalyzer{"fet"},
	}
	// TLS handshake record prefix; destination port plays no role here.
	tlsPrefix := []byte{0x16, 0x03, 0x03, 0x00, 0x10}
	names := make(map[string]bool)
	for _, a := range sel.SelectTCP(candidates, tlsPrefix) {
		names[a.Name()] = true
	}
	for _, want := range []string{"tls", "trojan", "fet"} {
		if !names[want] {
			t.Fatalf("expected analyzer %q to be selected", want)
		}
	}
	for _, dropped := range []string{"http", "ssh", "socks"} {
		if names[dropped] {
			t.Fatalf("expected analyzer %q to be pruned", dropped)
		}
	}
}
// TestSignatureSelectorConservativeFallback verifies that a payload matching
// no known signature leaves the analyzer set untouched (fail-open pruning).
func TestSignatureSelectorConservativeFallback(t *testing.T) {
	sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
	candidates := []analyzer.Analyzer{
		namedAnalyzer{"http"},
		namedAnalyzer{"tls"},
		namedAnalyzer{"custom"},
	}
	unknownPayload := []byte{0xde, 0xad, 0xbe, 0xef}
	selected := sel.SelectTCP(candidates, unknownPayload)
	if len(selected) != len(candidates) {
		t.Fatalf("expected conservative fallback to keep all analyzers, got=%d want=%d", len(selected), len(candidates))
	}
}
+68 -25
View File
@@ -2,6 +2,7 @@ package engine
import (
"context"
"runtime"
"sync"
"sync/atomic"
@@ -20,18 +21,34 @@ type engine struct {
logger Logger
io io.PacketIO
workers []*worker
verdicts sync.Map // streamID(uint32) → verdictEntry
stats *statsCounters
verdicts sync.Map // streamID(uint32) -> verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update
overflowCh chan *workerPacket
overflowOnce sync.Once
overflowPolicy OverflowPolicy
resultCh chan workerResult
}
func NewEngine(config Config) (Engine, error) {
workerCount := config.Workers
if workerCount <= 0 {
workerCount = 1
workerCount = runtime.GOMAXPROCS(0)
if workerCount <= 0 {
workerCount = 1
}
}
overflowPolicy := config.OverflowPolicy
if overflowPolicy == "" {
overflowPolicy = OverflowPolicyAccept
}
selectionMode := config.AnalyzerSelectionMode
if selectionMode == "" {
selectionMode = AnalyzerSelectionModeSignature
}
stats := &statsCounters{}
resultCh := make(chan workerResult, workerCount*256)
macResolver := newSourceMACResolver()
var err error
workers := make([]*worker, workerCount)
@@ -45,16 +62,21 @@ func NewEngine(config Config) (Engine, error) {
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
UDPMaxStreams: config.WorkerUDPMaxStreams,
AnalyzerSelectionMode: selectionMode,
ResultChan: resultCh,
Stats: stats,
})
if err != nil {
return nil, err
}
}
e := &engine{
logger: config.Logger,
io: config.IO,
workers: workers,
overflowCh: make(chan *workerPacket, 1024),
logger: config.Logger,
io: config.IO,
workers: workers,
stats: stats,
overflowPolicy: overflowPolicy,
resultCh: resultCh,
}
return e, nil
}
@@ -74,13 +96,10 @@ func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel()
e.overflowOnce.Do(func() {
go e.drainOverflow(ioCtx)
})
for _, w := range e.workers {
go w.Run(ioCtx)
}
go e.drainResults(ioCtx)
errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
@@ -121,24 +140,35 @@ func (e *engine) dispatch(p io.Packet) bool {
gen := e.verdictsGen.Load()
index := streamID % uint32(len(e.workers))
wp := &workerPacket{
StreamID: streamID,
Data: data,
SetVerdict: func(v io.Verdict, b []byte) error {
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
}
return e.io.SetVerdict(p, v, b)
},
Packet: p,
StreamID: streamID,
Data: data,
Gen: gen,
}
if !e.workers[index].Feed(wp) {
select {
case e.overflowCh <- wp:
e.stats.OverflowEvents.Add(1)
switch e.overflowPolicy {
case OverflowPolicyDrop:
e.stats.OverflowDrops.Add(1)
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
case OverflowPolicyBackpressure:
e.stats.OverflowBackpressureEvents.Add(1)
e.workers[index].FeedBlocking(wp)
default:
e.stats.OverflowAccepts.Add(1)
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
}
}
return true
}
// applyWorkerResult records stream-wide verdicts in the verdict cache (tagged
// with the ruleset generation the worker observed) and forwards the verdict to
// the packet IO layer.
func (e *engine) applyWorkerResult(r workerResult) {
	switch r.Verdict {
	case io.VerdictAcceptStream, io.VerdictDropStream:
		e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
	}
	_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}
func validPacket(data []byte) bool {
if len(data) == 0 {
return false
@@ -156,13 +186,26 @@ func validPacket(data []byte) bool {
return false
}
func (e *engine) drainOverflow(ctx context.Context) {
// drainResults consumes worker results from the shared result channel and
// applies them until ctx is cancelled.
func (e *engine) drainResults(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case result := <-e.resultCh:
			e.applyWorkerResult(result)
		}
	}
}
// Stats returns a snapshot of the engine's runtime counters.
// NOTE(review): each counter is loaded atomically, but the fields are read
// independently, so the snapshot is only approximately consistent under
// concurrent updates.
func (e *engine) Stats() Stats {
	return Stats{
		OverflowEvents:             e.stats.OverflowEvents.Load(),
		OverflowAccepts:            e.stats.OverflowAccepts.Load(),
		OverflowDrops:              e.stats.OverflowDrops.Load(),
		OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
		AnalyzerSelectionsTotal:    e.stats.AnalyzerSelectionsTotal.Load(),
		AnalyzerSelectionsPruned:   e.stats.AnalyzerSelectionsPruned.Load(),
		UDPTupleLookups:            e.stats.UDPTupleLookups.Load(),
		UDPTupleHits:               e.stats.UDPTupleHits.Load(),
	}
}
+46
View File
@@ -2,6 +2,7 @@ package engine
import (
"context"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
@@ -13,6 +14,49 @@ type Engine interface {
UpdateRuleset(ruleset.Ruleset) error
// Run runs the engine, until an error occurs or the context is cancelled.
Run(context.Context) error
// Stats returns a consistent snapshot of runtime counters.
Stats() Stats
}
// OverflowPolicy selects how the engine reacts when a worker's input queue is
// full at dispatch time.
type OverflowPolicy string

const (
	// OverflowPolicyAccept lets the packet through unanalyzed (fail-open).
	OverflowPolicyAccept OverflowPolicy = "accept"
	// OverflowPolicyDrop discards the packet (fail-closed).
	OverflowPolicyDrop OverflowPolicy = "drop"
	// OverflowPolicyBackpressure blocks dispatch until the worker accepts the packet.
	OverflowPolicyBackpressure OverflowPolicy = "backpressure"
)

// AnalyzerSelectionMode selects how analyzers are chosen for each new flow.
type AnalyzerSelectionMode string

const (
	// AnalyzerSelectionModeAlways runs every configured analyzer on every flow.
	AnalyzerSelectionModeAlways AnalyzerSelectionMode = "always"
	// AnalyzerSelectionModeSignature prunes analyzers using cheap payload
	// signature checks on the flow's first payload.
	AnalyzerSelectionModeSignature AnalyzerSelectionMode = "signature"
)

// statsCounters holds the engine's shared atomic runtime counters, updated by
// the engine and its workers.
type statsCounters struct {
	OverflowEvents             atomic.Uint64
	OverflowAccepts            atomic.Uint64
	OverflowDrops              atomic.Uint64
	OverflowBackpressureEvents atomic.Uint64
	AnalyzerSelectionsTotal    atomic.Uint64
	AnalyzerSelectionsPruned   atomic.Uint64
	UDPTupleLookups            atomic.Uint64
	UDPTupleHits               atomic.Uint64
}

// Stats is a plain-value snapshot of statsCounters, as returned by Engine.Stats.
type Stats struct {
	OverflowEvents             uint64
	OverflowAccepts            uint64
	OverflowDrops              uint64
	OverflowBackpressureEvents uint64
	AnalyzerSelectionsTotal    uint64
	AnalyzerSelectionsPruned   uint64
	UDPTupleLookups            uint64
	UDPTupleHits               uint64
}
// Config is the configuration for the engine.
@@ -26,6 +70,8 @@ type Config struct {
WorkerTCPMaxBufferedPagesTotal int
WorkerTCPMaxBufferedPagesPerConn int
WorkerUDPMaxStreams int
OverflowPolicy OverflowPolicy
AnalyzerSelectionMode AnalyzerSelectionMode
}
// Logger is the combined logging interface for the engine, workers and analyzers.
+3
View File
@@ -1,3 +1,6 @@
//go:build linux
// +build linux
package engine
import (
+17
View File
@@ -0,0 +1,17 @@
//go:build !linux
// +build !linux
package engine
import "net"
type sourceMACResolver struct{}
func newSourceMACResolver() *sourceMACResolver {
return &sourceMACResolver{}
}
func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
_ = ip
return nil
}
+2 -2
View File
@@ -142,7 +142,7 @@ func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
if err != nil {
t.Fatalf("create node: %v", err)
}
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
l3 := L3Info{
@@ -180,7 +180,7 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
if err != nil {
t.Fatalf("create node: %v", err)
}
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
l3 := L3Info{
+8 -4
View File
@@ -163,15 +163,17 @@ type tcpFlowManager struct {
rulesetSource func() (ruleset.Ruleset, uint64)
workerID int
macResolver *sourceMACResolver
selector *analyzerSelector
}
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
return &tcpFlowManager{
flows: make(map[uint32]*tcpFlow),
sfNode: node,
logger: logger,
workerID: workerID,
macResolver: macResolver,
selector: selector,
}
}
@@ -179,7 +181,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
m.mu.Lock()
flow, ok := m.flows[streamID]
if !ok {
flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
m.flows[streamID] = flow
}
m.mu.Unlock()
@@ -195,7 +197,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
return verdict
}
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
id := m.sfNode.Generate()
ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:])
@@ -217,7 +219,9 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, src
rs, version := m.rulesetSource()
var ans []analyzer.TCPAnalyzer
if rs != nil {
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
baseAns := rs.Analyzers(info)
baseAns = m.selector.SelectTCP(baseAns, payload)
ans = analyzersToTCPAnalyzers(baseAns)
}
entries := make([]*tcpFlowEntry, 0, len(ans))
for _, a := range ans {
+109 -21
View File
@@ -1,6 +1,7 @@
package engine
import (
"bytes"
"errors"
"net"
"sync"
@@ -40,6 +41,8 @@ type udpStreamFactory struct {
WorkerID int
Logger Logger
Node *snowflake.Node
Selector *analyzerSelector
Stats *statsCounters
RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset
@@ -64,7 +67,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
rs, version := f.currentRuleset()
var ans []analyzer.UDPAnalyzer
if rs != nil {
ans = analyzersToUDPAnalyzers(rs.Analyzers(info))
baseAns := rs.Analyzers(info)
if f.Selector != nil {
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
}
ans = analyzersToUDPAnalyzers(baseAns)
}
// Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans))
@@ -110,8 +117,11 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
}
type udpStreamManager struct {
factory *udpStreamFactory
streams *lru.Cache[uint32, *udpStreamValue]
factory *udpStreamFactory
streams *lru.Cache[uint32, *udpStreamValue]
tupleIndex map[udpTupleKey]uint32
streamTuples map[uint32]udpTupleKey
stats *statsCounters
}
type udpStreamValue struct {
@@ -120,36 +130,71 @@ type udpStreamValue struct {
UDPFlow gopacket.Flow
}
// udpTupleKey is a fixed-size, comparable, direction-independent key for a UDP
// 5-tuple. The two endpoints are stored in the canonical order chosen by
// canonicalUDPTupleKey, so both directions of a flow map to the same key.
type udpTupleKey struct {
	AIP   [16]byte // first endpoint address bytes, zero-padded
	BIP   [16]byte // second endpoint address bytes, zero-padded
	ALen  uint8    // number of significant bytes in AIP
	BLen  uint8    // number of significant bytes in BIP
	APort uint16
	BPort uint16
}
// Match reports whether the given IP/UDP flow pair belongs to this stream
// (ok), and whether it arrives in the reverse direction (rev).
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
	forward := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
	backward := v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
	return forward || backward, backward
}
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
ss, err := lru.New[uint32, *udpStreamValue](maxStreams)
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
m := &udpStreamManager{
factory: factory,
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
stats: stats,
}
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
m.removeTupleMappingLocked(k)
})
if err != nil {
return nil, err
}
return &udpStreamManager{
factory: factory,
streams: ss,
}, nil
m.streams = ss
return m, nil
}
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
rev := false
value, ok := m.streams.Get(streamID)
tuple := canonicalUDPTupleKey(ipFlow, udp)
if !ok {
// Fallback: conntrack IDs can change during early flow lifetime on some systems.
// Try to find an existing stream by 5-tuple before creating a new stream.
matchedKey, matchedValue, matchedRev, found := m.findByFlow(ipFlow, udp.TransportFlow())
if m.stats != nil {
m.stats.UDPTupleLookups.Add(1)
}
// Conntrack IDs can change during early flow lifetime on some systems.
// Rebind by canonical 5-tuple in O(1).
matchedKey, found := m.tupleIndex[tuple]
var matchedValue *udpStreamValue
var matchedRev bool
if found {
if m.stats != nil {
m.stats.UDPTupleHits.Add(1)
}
var hasValue bool
matchedValue, hasValue = m.streams.Get(matchedKey)
if !hasValue || matchedValue == nil {
delete(m.tupleIndex, tuple)
delete(m.streamTuples, matchedKey)
found = false
}
}
if found {
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
value = matchedValue
rev = matchedRev
if matchedKey != streamID {
m.streams.Remove(matchedKey)
m.streams.Add(streamID, matchedValue)
m.bindTupleLocked(streamID, tuple)
}
} else {
// New stream
@@ -159,6 +204,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
UDPFlow: udp.TransportFlow(),
}
m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
}
} else {
// Stream ID exists, but is it really the same stream?
@@ -172,6 +218,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
UDPFlow: udp.TransportFlow(),
}
m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
}
}
if value.Stream.Accept(udp, rev, uc) {
@@ -179,17 +226,58 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
}
}
func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32, value *udpStreamValue, rev bool, found bool) {
for _, k := range m.streams.Keys() {
v, ok := m.streams.Peek(k)
if !ok || v == nil {
continue
}
if ok2, rev2 := v.Match(ipFlow, udpFlow); ok2 {
return k, v, rev2, true
// bindTupleLocked associates streamID with the given canonical tuple key,
// replacing any previous binding for that stream. Caller must hold the
// manager's lock (per the Locked suffix convention).
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
	// Drop any stale binding before installing the new one.
	m.removeTupleMappingLocked(streamID)
	m.streamTuples[streamID] = key
	m.tupleIndex[key] = streamID
}
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
if key, ok := m.streamTuples[streamID]; ok {
delete(m.streamTuples, streamID)
current, exists := m.tupleIndex[key]
if exists && current == streamID {
delete(m.tupleIndex, key)
}
}
return 0, nil, false, false
}
// canonicalUDPTupleKey builds a direction-independent key for a UDP 5-tuple by
// placing the two (IP, port) endpoints in canonical order, so a packet and its
// reply produce the same key.
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
	aIP, bIP := ipFlow.Src().Raw(), ipFlow.Dst().Raw()
	aPort, bPort := uint16(udp.SrcPort), uint16(udp.DstPort)
	if compareIPEndpoint(aIP, aPort, bIP, bPort) > 0 {
		aIP, bIP = bIP, aIP
		aPort, bPort = bPort, aPort
	}
	key := udpTupleKey{APort: aPort, BPort: bPort}
	key.ALen = uint8(copy(key.AIP[:], aIP))
	key.BLen = uint8(copy(key.BIP[:], bIP))
	return key
}
// compareIPEndpoint imposes a total order on (IP, port) endpoints: first by
// address length, then lexicographically by address bytes, then by port.
// Returns a negative, zero, or positive result like bytes.Compare.
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
	if len(aIP) != len(bIP) {
		if len(aIP) < len(bIP) {
			return -1
		}
		return 1
	}
	if c := bytes.Compare(aIP, bIP); c != 0 {
		return c
	}
	switch {
	case aPort < bPort:
		return -1
	case aPort > bPort:
		return 1
	default:
		return 0
	}
}
type udpStream struct {
+122
View File
@@ -0,0 +1,122 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// legacyUDPStreamValue mirrors the pre-index stream value (IP flow + UDP flow
// only); used by the legacy linear-scan benchmark as a baseline.
type legacyUDPStreamValue struct {
	IPFlow  gopacket.Flow
	UDPFlow gopacket.Flow
}
// emptyRuleset is a ruleset stub with no analyzers and an always-"maybe"
// match result.
type emptyRuleset struct{}
// Analyzers always returns nil: no analyzers are attached to any stream.
func (emptyRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil }
// Match always defers the decision (ActionMaybe).
func (emptyRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
	return ruleset.MatchResult{Action: ruleset.ActionMaybe}
}
// benchmarkUDPManager drives MatchWithContext over a fixed set of synthetic
// UDP flows. With churn=false each flow keeps a stable stream ID (pure lookup
// hits); with churn=true the stream ID changes on every pass over the flow
// set, forcing the 5-tuple rebinding path on each packet.
func benchmarkUDPManager(b *testing.B, churn bool) {
	node, err := snowflake.NewNode(0)
	if err != nil {
		b.Fatalf("create node: %v", err)
	}
	factory := &udpStreamFactory{WorkerID: 0, Logger: noopTestLogger{}, Node: node, Ruleset: emptyRuleset{}}
	mgr, err := newUDPStreamManager(factory, 200000, &statsCounters{})
	if err != nil {
		b.Fatalf("new manager: %v", err)
	}
	const flowCount = 20000
	flows := make([]gopacket.Flow, flowCount)
	udps := make([]*layers.UDP, flowCount)
	for i := 0; i < flowCount; i++ {
		// Spread addresses across /16s so each flow is unique.
		a := byte(i >> 8)
		c := byte(i)
		flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
		udps[i] = &layers.UDP{
			SrcPort:   layers.UDPPort(1024 + i%20000),
			DstPort:   layers.UDPPort(20000 + (i*7)%20000),
			BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
		}
	}
	ctx := &udpContext{Verdict: udpVerdictAccept}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		idx := i % flowCount
		streamID := uint32(idx + 1)
		if churn {
			// Offset the base ID by a full flowCount per pass so every packet
			// arrives under a never-before-seen stream ID.
			streamID = uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
		}
		// Reset the reused context between iterations.
		ctx.Verdict = udpVerdictAccept
		ctx.Packet = nil
		mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
	}
}
// BenchmarkUDPManagerMatchStableStreamID measures the hot path where stream
// IDs never change between packets of the same flow.
func BenchmarkUDPManagerMatchStableStreamID(b *testing.B) {
	benchmarkUDPManager(b, false)
}
// BenchmarkUDPManagerMatchStreamIDChurn measures the rebinding path where
// every packet arrives under a fresh stream ID.
func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
	benchmarkUDPManager(b, true)
}
// BenchmarkLegacyUDPFallbackScanChurn reproduces the removed linear fallback:
// on a stream-ID miss it scans every known stream for a matching 5-tuple.
// Kept as a baseline to compare against the O(1) tuple-index rebinding.
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
	const flowCount = 5000
	flows := make([]gopacket.Flow, flowCount)
	udps := make([]*layers.UDP, flowCount)
	for i := 0; i < flowCount; i++ {
		a := byte(i >> 8)
		c := byte(i)
		flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
		udps[i] = &layers.UDP{
			SrcPort:   layers.UDPPort(1024 + i%20000),
			DstPort:   layers.UDPPort(20000 + (i*7)%20000),
			BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
		}
	}
	// Seed the stream table with one entry per flow.
	streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
	keys := make([]uint32, 0, flowCount)
	for i := 0; i < flowCount; i++ {
		streamID := uint32(i + 1)
		streams[streamID] = &legacyUDPStreamValue{
			IPFlow:  flows[i],
			UDPFlow: udps[i].TransportFlow(),
		}
		keys = append(keys, streamID)
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		idx := i % flowCount
		// Churn: each pass over the flow set presents a fresh stream ID.
		streamID := uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
		if _, ok := streams[streamID]; ok {
			continue
		}
		ipFlow := flows[idx]
		udpFlow := udps[idx].TransportFlow()
		// The O(n) scan over all streams that this benchmark measures.
		for _, k := range keys {
			v, ok := streams[k]
			if !ok || v == nil {
				continue
			}
			if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
				(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
				delete(streams, k)
				streams[streamID] = v
				break
			}
		}
	}
}
+71
View File
@@ -0,0 +1,71 @@
package engine
import (
"net"
"sync/atomic"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// countingRuleset returns a fixed analyzer list for every stream; the attached
// analyzers do the actual call counting.
type countingRuleset struct {
	ans []analyzer.Analyzer
}
// Analyzers returns the configured analyzer list regardless of stream info.
func (r countingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return r.ans }
// Match always defers the decision (ActionMaybe).
func (r countingRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
	return ruleset.MatchResult{Action: ruleset.ActionMaybe}
}
// countingUDPAnalyzer counts how many UDP streams are created from it via the
// shared newCalls counter.
type countingUDPAnalyzer struct{ newCalls *atomic.Uint64 }
func (a countingUDPAnalyzer) Name() string { return "countudp" }
func (a countingUDPAnalyzer) Limit() int { return 0 }
// NewUDP increments the shared counter and returns an inert stream.
func (a countingUDPAnalyzer) NewUDP(analyzer.UDPInfo, analyzer.Logger) analyzer.UDPStream {
	a.newCalls.Add(1)
	return countingUDPStream{}
}
// countingUDPStream is an inert UDP stream: consumes nothing, emits nothing.
type countingUDPStream struct{}
func (countingUDPStream) Feed(bool, []byte) (*analyzer.PropUpdate, bool) { return nil, false }
func (countingUDPStream) Close(bool) *analyzer.PropUpdate { return nil }
// TestUDPStreamManagerRebindsByTupleInO1Path verifies that when the same UDP
// 5-tuple reappears under a different stream ID, the manager rebinds the
// existing stream instead of creating a new one: NewUDP is called exactly once.
func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
	node, err := snowflake.NewNode(0)
	if err != nil {
		t.Fatalf("create node: %v", err)
	}
	var newCalls atomic.Uint64
	rs := countingRuleset{ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &newCalls}}}
	factory := &udpStreamFactory{
		WorkerID: 0,
		Logger:   noopTestLogger{},
		Node:     node,
		Ruleset:  rs,
	}
	mgr, err := newUDPStreamManager(factory, 64, &statsCounters{})
	if err != nil {
		t.Fatalf("new manager: %v", err)
	}
	ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
	udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
	// First packet under stream ID 100 creates the stream.
	ctx1 := &udpContext{Verdict: udpVerdictAccept}
	mgr.MatchWithContext(100, ipFlow, udp, ctx1)
	if got := newCalls.Load(); got != 1 {
		t.Fatalf("new stream calls=%d want=1", got)
	}
	// Same 5-tuple under a different stream ID must reuse the existing stream.
	ctx2 := &udpContext{Verdict: udpVerdictAccept}
	mgr.MatchWithContext(200, ipFlow, udp, ctx2)
	if got := newCalls.Load(); got != 1 {
		t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
	}
}
+38 -14
View File
@@ -12,24 +12,32 @@ import (
"github.com/google/gopacket/layers"
)
// Compile-time assertion that *engine satisfies the Engine interface.
var _ Engine = (*engine)(nil)

// workerPacket is a unit of work handed from the engine to a worker.
type workerPacket struct {
	Packet   io.Packet // original packet handle, used to set the verdict later
	StreamID uint32
	Data     []byte
	SrcMAC   net.HardwareAddr
	DstMAC   net.HardwareAddr
	Gen      int64 // ruleset generation observed at dispatch time
}

// workerResult carries a worker's verdict back to the engine's result drain.
type workerResult struct {
	Packet         io.Packet
	StreamID       uint32
	Verdict        io.Verdict
	ModifiedPacket []byte // replacement packet bytes, if the verdict modifies the packet
	Gen            int64  // ruleset generation propagated from the workerPacket
}

// worker processes the packets of its shard of streams on its own goroutine.
type worker struct {
	id                 int
	packetChan         chan *workerPacket
	resultChan         chan workerResult
	logger             Logger
	macResolver        *sourceMACResolver
	tcpFlowMgr         *tcpFlowManager
	udpSM              *udpStreamManager
	modSerializeBuffer gopacket.SerializeBuffer
}
@@ -43,6 +51,9 @@ type workerConfig struct {
TCPMaxBufferedPagesTotal int // unused, kept for config compat
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
UDPMaxStreams int
AnalyzerSelectionMode AnalyzerSelectionMode
ResultChan chan workerResult
Stats *statsCounters
}
func (c *workerConfig) fillDefaults() {
@@ -61,7 +72,8 @@ func newWorker(config workerConfig) (*worker, error) {
return nil, err
}
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
if config.Ruleset != nil {
tcpMgr.updateRuleset(config.Ruleset, 0)
}
@@ -71,8 +83,10 @@ func newWorker(config workerConfig) (*worker, error) {
Logger: config.Logger,
Node: sfNode,
Ruleset: config.Ruleset,
Selector: selector,
Stats: config.Stats,
}
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
if err != nil {
return nil, err
}
@@ -80,6 +94,7 @@ func newWorker(config workerConfig) (*worker, error) {
return &worker{
id: config.ID,
packetChan: make(chan *workerPacket, config.ChanSize),
resultChan: config.ResultChan,
logger: config.Logger,
macResolver: config.MACResolver,
tcpFlowMgr: tcpMgr,
@@ -97,6 +112,10 @@ func (w *worker) Feed(p *workerPacket) bool {
}
}
// FeedBlocking enqueues p, blocking until the worker has queue capacity.
// Used by the engine's backpressure overflow policy.
func (w *worker) FeedBlocking(p *workerPacket) {
	w.packetChan <- p
}
func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id)
@@ -109,7 +128,13 @@ func (w *worker) Run(ctx context.Context) {
return
}
v, b := w.handle(wp)
_ = wp.SetVerdict(v, b)
w.resultChan <- workerResult{
Packet: wp.Packet,
StreamID: wp.StreamID,
Verdict: v,
ModifiedPacket: b,
Gen: wp.Gen,
}
}
}
}
@@ -185,8 +210,7 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
SrcMAC: srcMAC,
DstMAC: dstMAC,
}
// Temporarily set payload on a UDP layer so existing UDP handling works
// We pass the payload through the context
// Temporarily set payload on a UDP layer so existing UDP handling works.
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
BaseLayer: layers.BaseLayer{Payload: payload},
SrcPort: layers.UDPPort(udp.SrcPort),