Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple. Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics. Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
This commit is contained in:
@@ -0,0 +1,235 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
type analyzerSelector struct {
|
||||
mode AnalyzerSelectionMode
|
||||
stats *statsCounters
|
||||
}
|
||||
|
||||
func newAnalyzerSelector(mode AnalyzerSelectionMode, stats *statsCounters) *analyzerSelector {
|
||||
if mode == "" {
|
||||
mode = AnalyzerSelectionModeSignature
|
||||
}
|
||||
return &analyzerSelector{mode: mode, stats: stats}
|
||||
}
|
||||
|
||||
func (s *analyzerSelector) SelectTCP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
|
||||
if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
|
||||
return ans
|
||||
}
|
||||
allowed := tcpAllowedAnalyzers(payload)
|
||||
if len(allowed) == 0 {
|
||||
return ans
|
||||
}
|
||||
out := make([]analyzer.Analyzer, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
name := strings.ToLower(a.Name())
|
||||
if _, known := knownTCPAnalyzers[name]; !known {
|
||||
out = append(out, a)
|
||||
continue
|
||||
}
|
||||
if allowed[name] {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
s.recordSelection(len(ans), len(out))
|
||||
if len(out) == 0 {
|
||||
return ans
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *analyzerSelector) SelectUDP(ans []analyzer.Analyzer, payload []byte) []analyzer.Analyzer {
|
||||
if s == nil || s.mode == AnalyzerSelectionModeAlways || len(ans) <= 1 {
|
||||
return ans
|
||||
}
|
||||
allowed := udpAllowedAnalyzers(payload)
|
||||
if len(allowed) == 0 {
|
||||
return ans
|
||||
}
|
||||
out := make([]analyzer.Analyzer, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
name := strings.ToLower(a.Name())
|
||||
if _, known := knownUDPAnalyzers[name]; !known {
|
||||
out = append(out, a)
|
||||
continue
|
||||
}
|
||||
if allowed[name] {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
s.recordSelection(len(ans), len(out))
|
||||
if len(out) == 0 {
|
||||
return ans
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *analyzerSelector) recordSelection(total, selected int) {
|
||||
if s == nil || s.stats == nil || total <= 0 {
|
||||
return
|
||||
}
|
||||
s.stats.AnalyzerSelectionsTotal.Add(1)
|
||||
if selected < total {
|
||||
s.stats.AnalyzerSelectionsPruned.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
knownTCPAnalyzers = map[string]struct{}{
|
||||
"fet": {},
|
||||
"http": {},
|
||||
"socks": {},
|
||||
"ssh": {},
|
||||
"tls": {},
|
||||
"trojan": {},
|
||||
"dns": {},
|
||||
"openvpn": {},
|
||||
}
|
||||
knownUDPAnalyzers = map[string]struct{}{
|
||||
"dns": {},
|
||||
"openvpn": {},
|
||||
"quic": {},
|
||||
"wireguard": {},
|
||||
}
|
||||
)
|
||||
|
||||
func tcpAllowedAnalyzers(payload []byte) map[string]bool {
|
||||
allowed := make(map[string]bool, 4)
|
||||
if looksLikeTLS(payload) {
|
||||
allowed["tls"] = true
|
||||
allowed["trojan"] = true
|
||||
allowed["fet"] = true
|
||||
}
|
||||
if looksLikeHTTP(payload) {
|
||||
allowed["http"] = true
|
||||
allowed["fet"] = true
|
||||
}
|
||||
if looksLikeSSH(payload) {
|
||||
allowed["ssh"] = true
|
||||
allowed["fet"] = true
|
||||
}
|
||||
if looksLikeSOCKS(payload) {
|
||||
allowed["socks"] = true
|
||||
allowed["fet"] = true
|
||||
}
|
||||
if looksLikeDNSTCP(payload) {
|
||||
allowed["dns"] = true
|
||||
allowed["fet"] = true
|
||||
}
|
||||
if len(allowed) == 0 {
|
||||
return nil
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
func udpAllowedAnalyzers(payload []byte) map[string]bool {
|
||||
allowed := make(map[string]bool, 4)
|
||||
if looksLikeWireGuard(payload) {
|
||||
allowed["wireguard"] = true
|
||||
}
|
||||
if looksLikeOpenVPN(payload) {
|
||||
allowed["openvpn"] = true
|
||||
}
|
||||
if looksLikeQUIC(payload) {
|
||||
allowed["quic"] = true
|
||||
}
|
||||
if looksLikeDNSUDP(payload) {
|
||||
allowed["dns"] = true
|
||||
}
|
||||
if len(allowed) == 0 {
|
||||
return nil
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
func looksLikeTLS(payload []byte) bool {
|
||||
if len(payload) < 3 {
|
||||
return false
|
||||
}
|
||||
return (payload[0] == 0x16 || payload[0] == 0x17) && payload[1] == 0x03 && payload[2] <= 0x09
|
||||
}
|
||||
|
||||
func looksLikeHTTP(payload []byte) bool {
|
||||
if len(payload) < 3 {
|
||||
return false
|
||||
}
|
||||
head := strings.ToUpper(string(payload[:3]))
|
||||
switch head {
|
||||
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeSSH(payload []byte) bool {
|
||||
return len(payload) >= 4 && bytes.HasPrefix(payload, []byte("SSH-"))
|
||||
}
|
||||
|
||||
func looksLikeSOCKS(payload []byte) bool {
|
||||
if len(payload) < 2 {
|
||||
return false
|
||||
}
|
||||
return payload[0] == 0x04 || payload[0] == 0x05
|
||||
}
|
||||
|
||||
func looksLikeDNSTCP(payload []byte) bool {
|
||||
if len(payload) < 14 {
|
||||
return false
|
||||
}
|
||||
msgLen := int(payload[0])<<8 | int(payload[1])
|
||||
if msgLen <= 0 || msgLen+2 > len(payload) {
|
||||
return false
|
||||
}
|
||||
qd := int(payload[6])<<8 | int(payload[7])
|
||||
an := int(payload[8])<<8 | int(payload[9])
|
||||
return qd+an > 0
|
||||
}
|
||||
|
||||
func looksLikeDNSUDP(payload []byte) bool {
|
||||
if len(payload) < 12 {
|
||||
return false
|
||||
}
|
||||
qd := int(payload[4])<<8 | int(payload[5])
|
||||
an := int(payload[6])<<8 | int(payload[7])
|
||||
ns := int(payload[8])<<8 | int(payload[9])
|
||||
ar := int(payload[10])<<8 | int(payload[11])
|
||||
return qd+an+ns+ar > 0
|
||||
}
|
||||
|
||||
func looksLikeQUIC(payload []byte) bool {
|
||||
if len(payload) < 6 {
|
||||
return false
|
||||
}
|
||||
// Long header with non-zero version.
|
||||
if payload[0]&0x80 == 0 {
|
||||
return false
|
||||
}
|
||||
version := uint32(payload[1])<<24 | uint32(payload[2])<<16 | uint32(payload[3])<<8 | uint32(payload[4])
|
||||
return version != 0
|
||||
}
|
||||
|
||||
func looksLikeOpenVPN(payload []byte) bool {
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
opcode := payload[0] >> 3
|
||||
return opcode >= 1 && opcode <= 11
|
||||
}
|
||||
|
||||
func looksLikeWireGuard(payload []byte) bool {
|
||||
if len(payload) < 4 {
|
||||
return false
|
||||
}
|
||||
if payload[0] < 1 || payload[0] > 4 {
|
||||
return false
|
||||
}
|
||||
return payload[1] == 0 && payload[2] == 0 && payload[3] == 0
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
type namedAnalyzer struct{ name string }
|
||||
|
||||
func (a namedAnalyzer) Name() string { return a.name }
|
||||
func (a namedAnalyzer) Limit() int { return 0 }
|
||||
|
||||
func TestSignatureSelectorTCPPrunesByPayloadNotPort(t *testing.T) {
|
||||
sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
|
||||
all := []analyzer.Analyzer{
|
||||
namedAnalyzer{"http"},
|
||||
namedAnalyzer{"tls"},
|
||||
namedAnalyzer{"trojan"},
|
||||
namedAnalyzer{"ssh"},
|
||||
namedAnalyzer{"socks"},
|
||||
namedAnalyzer{"fet"},
|
||||
}
|
||||
// TLS record-like prefix, regardless of destination port.
|
||||
payload := []byte{0x16, 0x03, 0x03, 0x00, 0x10}
|
||||
selected := sel.SelectTCP(all, payload)
|
||||
|
||||
got := make(map[string]bool)
|
||||
for _, a := range selected {
|
||||
got[a.Name()] = true
|
||||
}
|
||||
for _, keep := range []string{"tls", "trojan", "fet"} {
|
||||
if !got[keep] {
|
||||
t.Fatalf("expected analyzer %q to be selected", keep)
|
||||
}
|
||||
}
|
||||
for _, drop := range []string{"http", "ssh", "socks"} {
|
||||
if got[drop] {
|
||||
t.Fatalf("expected analyzer %q to be pruned", drop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignatureSelectorConservativeFallback(t *testing.T) {
|
||||
sel := newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{})
|
||||
all := []analyzer.Analyzer{
|
||||
namedAnalyzer{"http"},
|
||||
namedAnalyzer{"tls"},
|
||||
namedAnalyzer{"custom"},
|
||||
}
|
||||
payload := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
selected := sel.SelectTCP(all, payload)
|
||||
if len(selected) != len(all) {
|
||||
t.Fatalf("expected conservative fallback to keep all analyzers, got=%d want=%d", len(selected), len(all))
|
||||
}
|
||||
}
|
||||
+68
-25
@@ -2,6 +2,7 @@ package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -20,18 +21,34 @@ type engine struct {
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
verdicts sync.Map // streamID(uint32) → verdictEntry
|
||||
stats *statsCounters
|
||||
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
||||
verdictsGen atomic.Int64 // incremented on ruleset update
|
||||
|
||||
overflowCh chan *workerPacket
|
||||
overflowOnce sync.Once
|
||||
overflowPolicy OverflowPolicy
|
||||
resultCh chan workerResult
|
||||
}
|
||||
|
||||
func NewEngine(config Config) (Engine, error) {
|
||||
workerCount := config.Workers
|
||||
if workerCount <= 0 {
|
||||
workerCount = 1
|
||||
workerCount = runtime.GOMAXPROCS(0)
|
||||
if workerCount <= 0 {
|
||||
workerCount = 1
|
||||
}
|
||||
}
|
||||
overflowPolicy := config.OverflowPolicy
|
||||
if overflowPolicy == "" {
|
||||
overflowPolicy = OverflowPolicyAccept
|
||||
}
|
||||
selectionMode := config.AnalyzerSelectionMode
|
||||
if selectionMode == "" {
|
||||
selectionMode = AnalyzerSelectionModeSignature
|
||||
}
|
||||
|
||||
stats := &statsCounters{}
|
||||
resultCh := make(chan workerResult, workerCount*256)
|
||||
|
||||
macResolver := newSourceMACResolver()
|
||||
var err error
|
||||
workers := make([]*worker, workerCount)
|
||||
@@ -45,16 +62,21 @@ func NewEngine(config Config) (Engine, error) {
|
||||
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||||
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||||
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||||
AnalyzerSelectionMode: selectionMode,
|
||||
ResultChan: resultCh,
|
||||
Stats: stats,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
e := &engine{
|
||||
logger: config.Logger,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
overflowCh: make(chan *workerPacket, 1024),
|
||||
logger: config.Logger,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
stats: stats,
|
||||
overflowPolicy: overflowPolicy,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
return e, nil
|
||||
}
|
||||
@@ -74,13 +96,10 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||
defer ioCancel()
|
||||
|
||||
e.overflowOnce.Do(func() {
|
||||
go e.drainOverflow(ioCtx)
|
||||
})
|
||||
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
go e.drainResults(ioCtx)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
@@ -121,24 +140,35 @@ func (e *engine) dispatch(p io.Packet) bool {
|
||||
gen := e.verdictsGen.Load()
|
||||
index := streamID % uint32(len(e.workers))
|
||||
wp := &workerPacket{
|
||||
StreamID: streamID,
|
||||
Data: data,
|
||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
||||
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
||||
}
|
||||
return e.io.SetVerdict(p, v, b)
|
||||
},
|
||||
Packet: p,
|
||||
StreamID: streamID,
|
||||
Data: data,
|
||||
Gen: gen,
|
||||
}
|
||||
if !e.workers[index].Feed(wp) {
|
||||
select {
|
||||
case e.overflowCh <- wp:
|
||||
e.stats.OverflowEvents.Add(1)
|
||||
switch e.overflowPolicy {
|
||||
case OverflowPolicyDrop:
|
||||
e.stats.OverflowDrops.Add(1)
|
||||
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
|
||||
case OverflowPolicyBackpressure:
|
||||
e.stats.OverflowBackpressureEvents.Add(1)
|
||||
e.workers[index].FeedBlocking(wp)
|
||||
default:
|
||||
e.stats.OverflowAccepts.Add(1)
|
||||
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *engine) applyWorkerResult(r workerResult) {
|
||||
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
|
||||
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
|
||||
}
|
||||
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
||||
}
|
||||
|
||||
func validPacket(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return false
|
||||
@@ -156,13 +186,26 @@ func validPacket(data []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *engine) drainOverflow(ctx context.Context) {
|
||||
func (e *engine) drainResults(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case wp := <-e.overflowCh:
|
||||
_ = wp.SetVerdict(io.VerdictAccept, nil)
|
||||
case r := <-e.resultCh:
|
||||
e.applyWorkerResult(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *engine) Stats() Stats {
|
||||
return Stats{
|
||||
OverflowEvents: e.stats.OverflowEvents.Load(),
|
||||
OverflowAccepts: e.stats.OverflowAccepts.Load(),
|
||||
OverflowDrops: e.stats.OverflowDrops.Load(),
|
||||
OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
|
||||
AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(),
|
||||
AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(),
|
||||
UDPTupleLookups: e.stats.UDPTupleLookups.Load(),
|
||||
UDPTupleHits: e.stats.UDPTupleHits.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
@@ -13,6 +14,49 @@ type Engine interface {
|
||||
UpdateRuleset(ruleset.Ruleset) error
|
||||
// Run runs the engine, until an error occurs or the context is cancelled.
|
||||
Run(context.Context) error
|
||||
// Stats returns a consistent snapshot of runtime counters.
|
||||
Stats() Stats
|
||||
}
|
||||
|
||||
type OverflowPolicy string
|
||||
|
||||
const (
|
||||
OverflowPolicyAccept OverflowPolicy = "accept"
|
||||
OverflowPolicyDrop OverflowPolicy = "drop"
|
||||
OverflowPolicyBackpressure OverflowPolicy = "backpressure"
|
||||
)
|
||||
|
||||
type AnalyzerSelectionMode string
|
||||
|
||||
const (
|
||||
AnalyzerSelectionModeAlways AnalyzerSelectionMode = "always"
|
||||
AnalyzerSelectionModeSignature AnalyzerSelectionMode = "signature"
|
||||
)
|
||||
|
||||
type statsCounters struct {
|
||||
OverflowEvents atomic.Uint64
|
||||
OverflowAccepts atomic.Uint64
|
||||
OverflowDrops atomic.Uint64
|
||||
OverflowBackpressureEvents atomic.Uint64
|
||||
|
||||
AnalyzerSelectionsTotal atomic.Uint64
|
||||
AnalyzerSelectionsPruned atomic.Uint64
|
||||
|
||||
UDPTupleLookups atomic.Uint64
|
||||
UDPTupleHits atomic.Uint64
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
OverflowEvents uint64
|
||||
OverflowAccepts uint64
|
||||
OverflowDrops uint64
|
||||
OverflowBackpressureEvents uint64
|
||||
|
||||
AnalyzerSelectionsTotal uint64
|
||||
AnalyzerSelectionsPruned uint64
|
||||
|
||||
UDPTupleLookups uint64
|
||||
UDPTupleHits uint64
|
||||
}
|
||||
|
||||
// Config is the configuration for the engine.
|
||||
@@ -26,6 +70,8 @@ type Config struct {
|
||||
WorkerTCPMaxBufferedPagesTotal int
|
||||
WorkerTCPMaxBufferedPagesPerConn int
|
||||
WorkerUDPMaxStreams int
|
||||
OverflowPolicy OverflowPolicy
|
||||
AnalyzerSelectionMode AnalyzerSelectionMode
|
||||
}
|
||||
|
||||
// Logger is the combined logging interface for the engine, workers and analyzers.
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package engine
|
||||
|
||||
import (
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package engine
|
||||
|
||||
import "net"
|
||||
|
||||
type sourceMACResolver struct{}
|
||||
|
||||
func newSourceMACResolver() *sourceMACResolver {
|
||||
return &sourceMACResolver{}
|
||||
}
|
||||
|
||||
func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
||||
_ = ip
|
||||
return nil
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("create node: %v", err)
|
||||
}
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
|
||||
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||
|
||||
l3 := L3Info{
|
||||
@@ -180,7 +180,7 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("create node: %v", err)
|
||||
}
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, nil)
|
||||
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||
|
||||
l3 := L3Info{
|
||||
|
||||
+8
-4
@@ -163,15 +163,17 @@ type tcpFlowManager struct {
|
||||
rulesetSource func() (ruleset.Ruleset, uint64)
|
||||
workerID int
|
||||
macResolver *sourceMACResolver
|
||||
selector *analyzerSelector
|
||||
}
|
||||
|
||||
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
|
||||
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
|
||||
return &tcpFlowManager{
|
||||
flows: make(map[uint32]*tcpFlow),
|
||||
sfNode: node,
|
||||
logger: logger,
|
||||
workerID: workerID,
|
||||
macResolver: macResolver,
|
||||
selector: selector,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +181,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
|
||||
m.mu.Lock()
|
||||
flow, ok := m.flows[streamID]
|
||||
if !ok {
|
||||
flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
|
||||
flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
|
||||
m.flows[streamID] = flow
|
||||
}
|
||||
m.mu.Unlock()
|
||||
@@ -195,7 +197,7 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
|
||||
return verdict
|
||||
}
|
||||
|
||||
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
|
||||
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
|
||||
id := m.sfNode.Generate()
|
||||
ipSrc := net.IP(l3.SrcIP[:])
|
||||
ipDst := net.IP(l3.DstIP[:])
|
||||
@@ -217,7 +219,9 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, src
|
||||
rs, version := m.rulesetSource()
|
||||
var ans []analyzer.TCPAnalyzer
|
||||
if rs != nil {
|
||||
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
|
||||
baseAns := rs.Analyzers(info)
|
||||
baseAns = m.selector.SelectTCP(baseAns, payload)
|
||||
ans = analyzersToTCPAnalyzers(baseAns)
|
||||
}
|
||||
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
|
||||
+109
-21
@@ -1,6 +1,7 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -40,6 +41,8 @@ type udpStreamFactory struct {
|
||||
WorkerID int
|
||||
Logger Logger
|
||||
Node *snowflake.Node
|
||||
Selector *analyzerSelector
|
||||
Stats *statsCounters
|
||||
|
||||
RulesetMutex sync.RWMutex
|
||||
Ruleset ruleset.Ruleset
|
||||
@@ -64,7 +67,11 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
||||
rs, version := f.currentRuleset()
|
||||
var ans []analyzer.UDPAnalyzer
|
||||
if rs != nil {
|
||||
ans = analyzersToUDPAnalyzers(rs.Analyzers(info))
|
||||
baseAns := rs.Analyzers(info)
|
||||
if f.Selector != nil {
|
||||
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
|
||||
}
|
||||
ans = analyzersToUDPAnalyzers(baseAns)
|
||||
}
|
||||
// Create entries for each analyzer
|
||||
entries := make([]*udpStreamEntry, 0, len(ans))
|
||||
@@ -110,8 +117,11 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
||||
}
|
||||
|
||||
type udpStreamManager struct {
|
||||
factory *udpStreamFactory
|
||||
streams *lru.Cache[uint32, *udpStreamValue]
|
||||
factory *udpStreamFactory
|
||||
streams *lru.Cache[uint32, *udpStreamValue]
|
||||
tupleIndex map[udpTupleKey]uint32
|
||||
streamTuples map[uint32]udpTupleKey
|
||||
stats *statsCounters
|
||||
}
|
||||
|
||||
type udpStreamValue struct {
|
||||
@@ -120,36 +130,71 @@ type udpStreamValue struct {
|
||||
UDPFlow gopacket.Flow
|
||||
}
|
||||
|
||||
type udpTupleKey struct {
|
||||
AIP [16]byte
|
||||
BIP [16]byte
|
||||
ALen uint8
|
||||
BLen uint8
|
||||
APort uint16
|
||||
BPort uint16
|
||||
}
|
||||
|
||||
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
||||
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
||||
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
||||
return fwd || rev, rev
|
||||
}
|
||||
|
||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
|
||||
ss, err := lru.New[uint32, *udpStreamValue](maxStreams)
|
||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||
m := &udpStreamManager{
|
||||
factory: factory,
|
||||
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
||||
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
||||
stats: stats,
|
||||
}
|
||||
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
||||
m.removeTupleMappingLocked(k)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &udpStreamManager{
|
||||
factory: factory,
|
||||
streams: ss,
|
||||
}, nil
|
||||
m.streams = ss
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
||||
rev := false
|
||||
value, ok := m.streams.Get(streamID)
|
||||
tuple := canonicalUDPTupleKey(ipFlow, udp)
|
||||
if !ok {
|
||||
// Fallback: conntrack IDs can change during early flow lifetime on some systems.
|
||||
// Try to find an existing stream by 5-tuple before creating a new stream.
|
||||
matchedKey, matchedValue, matchedRev, found := m.findByFlow(ipFlow, udp.TransportFlow())
|
||||
if m.stats != nil {
|
||||
m.stats.UDPTupleLookups.Add(1)
|
||||
}
|
||||
// Conntrack IDs can change during early flow lifetime on some systems.
|
||||
// Rebind by canonical 5-tuple in O(1).
|
||||
matchedKey, found := m.tupleIndex[tuple]
|
||||
var matchedValue *udpStreamValue
|
||||
var matchedRev bool
|
||||
if found {
|
||||
if m.stats != nil {
|
||||
m.stats.UDPTupleHits.Add(1)
|
||||
}
|
||||
var hasValue bool
|
||||
matchedValue, hasValue = m.streams.Get(matchedKey)
|
||||
if !hasValue || matchedValue == nil {
|
||||
delete(m.tupleIndex, tuple)
|
||||
delete(m.streamTuples, matchedKey)
|
||||
found = false
|
||||
}
|
||||
}
|
||||
if found {
|
||||
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
|
||||
value = matchedValue
|
||||
rev = matchedRev
|
||||
if matchedKey != streamID {
|
||||
m.streams.Remove(matchedKey)
|
||||
m.streams.Add(streamID, matchedValue)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
// New stream
|
||||
@@ -159,6 +204,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
// Stream ID exists, but is it really the same stream?
|
||||
@@ -172,6 +218,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
}
|
||||
if value.Stream.Accept(udp, rev, uc) {
|
||||
@@ -179,17 +226,58 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) findByFlow(ipFlow, udpFlow gopacket.Flow) (key uint32, value *udpStreamValue, rev bool, found bool) {
|
||||
for _, k := range m.streams.Keys() {
|
||||
v, ok := m.streams.Peek(k)
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
if ok2, rev2 := v.Match(ipFlow, udpFlow); ok2 {
|
||||
return k, v, rev2, true
|
||||
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
||||
m.removeTupleMappingLocked(streamID)
|
||||
m.tupleIndex[key] = streamID
|
||||
m.streamTuples[streamID] = key
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
|
||||
if key, ok := m.streamTuples[streamID]; ok {
|
||||
delete(m.streamTuples, streamID)
|
||||
current, exists := m.tupleIndex[key]
|
||||
if exists && current == streamID {
|
||||
delete(m.tupleIndex, key)
|
||||
}
|
||||
}
|
||||
return 0, nil, false, false
|
||||
}
|
||||
|
||||
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
|
||||
srcIP := ipFlow.Src().Raw()
|
||||
dstIP := ipFlow.Dst().Raw()
|
||||
srcPort := uint16(udp.SrcPort)
|
||||
dstPort := uint16(udp.DstPort)
|
||||
|
||||
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
|
||||
srcIP, dstIP = dstIP, srcIP
|
||||
srcPort, dstPort = dstPort, srcPort
|
||||
}
|
||||
|
||||
var key udpTupleKey
|
||||
key.ALen = uint8(copy(key.AIP[:], srcIP))
|
||||
key.BLen = uint8(copy(key.BIP[:], dstIP))
|
||||
key.APort = srcPort
|
||||
key.BPort = dstPort
|
||||
return key
|
||||
}
|
||||
|
||||
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
|
||||
if len(aIP) != len(bIP) {
|
||||
if len(aIP) < len(bIP) {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if c := bytes.Compare(aIP, bIP); c != 0 {
|
||||
return c
|
||||
}
|
||||
if aPort < bPort {
|
||||
return -1
|
||||
}
|
||||
if aPort > bPort {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type udpStream struct {
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
type legacyUDPStreamValue struct {
|
||||
IPFlow gopacket.Flow
|
||||
UDPFlow gopacket.Flow
|
||||
}
|
||||
|
||||
type emptyRuleset struct{}
|
||||
|
||||
func (emptyRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return nil }
|
||||
func (emptyRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||
}
|
||||
|
||||
func benchmarkUDPManager(b *testing.B, churn bool) {
|
||||
node, err := snowflake.NewNode(0)
|
||||
if err != nil {
|
||||
b.Fatalf("create node: %v", err)
|
||||
}
|
||||
factory := &udpStreamFactory{WorkerID: 0, Logger: noopTestLogger{}, Node: node, Ruleset: emptyRuleset{}}
|
||||
mgr, err := newUDPStreamManager(factory, 200000, &statsCounters{})
|
||||
if err != nil {
|
||||
b.Fatalf("new manager: %v", err)
|
||||
}
|
||||
|
||||
const flowCount = 20000
|
||||
flows := make([]gopacket.Flow, flowCount)
|
||||
udps := make([]*layers.UDP, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
a := byte(i >> 8)
|
||||
c := byte(i)
|
||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||
udps[i] = &layers.UDP{
|
||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||
}
|
||||
}
|
||||
|
||||
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
idx := i % flowCount
|
||||
streamID := uint32(idx + 1)
|
||||
if churn {
|
||||
streamID = uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
|
||||
}
|
||||
ctx.Verdict = udpVerdictAccept
|
||||
ctx.Packet = nil
|
||||
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUDPManagerMatchStableStreamID(b *testing.B) {
|
||||
benchmarkUDPManager(b, false)
|
||||
}
|
||||
|
||||
func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
|
||||
benchmarkUDPManager(b, true)
|
||||
}
|
||||
|
||||
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
||||
const flowCount = 5000
|
||||
flows := make([]gopacket.Flow, flowCount)
|
||||
udps := make([]*layers.UDP, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
a := byte(i >> 8)
|
||||
c := byte(i)
|
||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||
udps[i] = &layers.UDP{
|
||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||
}
|
||||
}
|
||||
|
||||
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
|
||||
keys := make([]uint32, 0, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
streamID := uint32(i + 1)
|
||||
streams[streamID] = &legacyUDPStreamValue{
|
||||
IPFlow: flows[i],
|
||||
UDPFlow: udps[i].TransportFlow(),
|
||||
}
|
||||
keys = append(keys, streamID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
idx := i % flowCount
|
||||
streamID := uint32((i % flowCount) + 1 + ((i / flowCount) * flowCount))
|
||||
if _, ok := streams[streamID]; ok {
|
||||
continue
|
||||
}
|
||||
ipFlow := flows[idx]
|
||||
udpFlow := udps[idx].TransportFlow()
|
||||
for _, k := range keys {
|
||||
v, ok := streams[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
|
||||
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
|
||||
delete(streams, k)
|
||||
streams[streamID] = v
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
type countingRuleset struct {
|
||||
ans []analyzer.Analyzer
|
||||
}
|
||||
|
||||
func (r countingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer { return r.ans }
|
||||
func (r countingRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||
}
|
||||
|
||||
type countingUDPAnalyzer struct{ newCalls *atomic.Uint64 }
|
||||
|
||||
func (a countingUDPAnalyzer) Name() string { return "countudp" }
|
||||
func (a countingUDPAnalyzer) Limit() int { return 0 }
|
||||
func (a countingUDPAnalyzer) NewUDP(analyzer.UDPInfo, analyzer.Logger) analyzer.UDPStream {
|
||||
a.newCalls.Add(1)
|
||||
return countingUDPStream{}
|
||||
}
|
||||
|
||||
type countingUDPStream struct{}
|
||||
|
||||
func (countingUDPStream) Feed(bool, []byte) (*analyzer.PropUpdate, bool) { return nil, false }
|
||||
func (countingUDPStream) Close(bool) *analyzer.PropUpdate { return nil }
|
||||
|
||||
func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
|
||||
node, err := snowflake.NewNode(0)
|
||||
if err != nil {
|
||||
t.Fatalf("create node: %v", err)
|
||||
}
|
||||
var newCalls atomic.Uint64
|
||||
rs := countingRuleset{ans: []analyzer.Analyzer{countingUDPAnalyzer{newCalls: &newCalls}}}
|
||||
factory := &udpStreamFactory{
|
||||
WorkerID: 0,
|
||||
Logger: noopTestLogger{},
|
||||
Node: node,
|
||||
Ruleset: rs,
|
||||
}
|
||||
mgr, err := newUDPStreamManager(factory, 64, &statsCounters{})
|
||||
if err != nil {
|
||||
t.Fatalf("new manager: %v", err)
|
||||
}
|
||||
|
||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
||||
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
|
||||
|
||||
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||
mgr.MatchWithContext(100, ipFlow, udp, ctx1)
|
||||
if got := newCalls.Load(); got != 1 {
|
||||
t.Fatalf("new stream calls=%d want=1", got)
|
||||
}
|
||||
|
||||
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||
mgr.MatchWithContext(200, ipFlow, udp, ctx2)
|
||||
if got := newCalls.Load(); got != 1 {
|
||||
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
|
||||
}
|
||||
}
|
||||
+38
-14
@@ -12,24 +12,32 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
var _ Engine = (*engine)(nil)
|
||||
|
||||
type workerPacket struct {
|
||||
StreamID uint32
|
||||
Data []byte
|
||||
SrcMAC net.HardwareAddr
|
||||
DstMAC net.HardwareAddr
|
||||
SetVerdict func(io.Verdict, []byte) error
|
||||
Packet io.Packet
|
||||
StreamID uint32
|
||||
Data []byte
|
||||
SrcMAC net.HardwareAddr
|
||||
DstMAC net.HardwareAddr
|
||||
Gen int64
|
||||
}
|
||||
|
||||
type workerResult struct {
|
||||
Packet io.Packet
|
||||
StreamID uint32
|
||||
Verdict io.Verdict
|
||||
ModifiedPacket []byte
|
||||
Gen int64
|
||||
}
|
||||
|
||||
type worker struct {
|
||||
id int
|
||||
packetChan chan *workerPacket
|
||||
resultChan chan workerResult
|
||||
logger Logger
|
||||
macResolver *sourceMACResolver
|
||||
|
||||
tcpFlowMgr *tcpFlowManager
|
||||
udpSM *udpStreamManager
|
||||
tcpFlowMgr *tcpFlowManager
|
||||
udpSM *udpStreamManager
|
||||
|
||||
modSerializeBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
@@ -43,6 +51,9 @@ type workerConfig struct {
|
||||
TCPMaxBufferedPagesTotal int // unused, kept for config compat
|
||||
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
|
||||
UDPMaxStreams int
|
||||
AnalyzerSelectionMode AnalyzerSelectionMode
|
||||
ResultChan chan workerResult
|
||||
Stats *statsCounters
|
||||
}
|
||||
|
||||
func (c *workerConfig) fillDefaults() {
|
||||
@@ -61,7 +72,8 @@ func newWorker(config workerConfig) (*worker, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
|
||||
selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)
|
||||
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
|
||||
if config.Ruleset != nil {
|
||||
tcpMgr.updateRuleset(config.Ruleset, 0)
|
||||
}
|
||||
@@ -71,8 +83,10 @@ func newWorker(config workerConfig) (*worker, error) {
|
||||
Logger: config.Logger,
|
||||
Node: sfNode,
|
||||
Ruleset: config.Ruleset,
|
||||
Selector: selector,
|
||||
Stats: config.Stats,
|
||||
}
|
||||
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
|
||||
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,6 +94,7 @@ func newWorker(config workerConfig) (*worker, error) {
|
||||
return &worker{
|
||||
id: config.ID,
|
||||
packetChan: make(chan *workerPacket, config.ChanSize),
|
||||
resultChan: config.ResultChan,
|
||||
logger: config.Logger,
|
||||
macResolver: config.MACResolver,
|
||||
tcpFlowMgr: tcpMgr,
|
||||
@@ -97,6 +112,10 @@ func (w *worker) Feed(p *workerPacket) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *worker) FeedBlocking(p *workerPacket) {
|
||||
w.packetChan <- p
|
||||
}
|
||||
|
||||
func (w *worker) Run(ctx context.Context) {
|
||||
w.logger.WorkerStart(w.id)
|
||||
defer w.logger.WorkerStop(w.id)
|
||||
@@ -109,7 +128,13 @@ func (w *worker) Run(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
v, b := w.handle(wp)
|
||||
_ = wp.SetVerdict(v, b)
|
||||
w.resultChan <- workerResult{
|
||||
Packet: wp.Packet,
|
||||
StreamID: wp.StreamID,
|
||||
Verdict: v,
|
||||
ModifiedPacket: b,
|
||||
Gen: wp.Gen,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -185,8 +210,7 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
|
||||
SrcMAC: srcMAC,
|
||||
DstMAC: dstMAC,
|
||||
}
|
||||
// Temporarily set payload on a UDP layer so existing UDP handling works
|
||||
// We pass the payload through the context
|
||||
// Temporarily set payload on a UDP layer so existing UDP handling works.
|
||||
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
||||
BaseLayer: layers.BaseLayer{Payload: payload},
|
||||
SrcPort: layers.UDPPort(udp.SrcPort),
|
||||
|
||||
Reference in New Issue
Block a user