7a3f6e945d
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple. Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics. Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
212 lines
4.9 KiB
Go
212 lines
4.9 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/io"
|
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
|
)
|
|
|
|
// Compile-time assertion that *engine satisfies the Engine interface.
var (
	_ Engine = (*engine)(nil)
)
|
|
|
|
type verdictEntry struct {
|
|
Verdict io.Verdict
|
|
Gen int64
|
|
}
|
|
|
|
type engine struct {
|
|
logger Logger
|
|
io io.PacketIO
|
|
workers []*worker
|
|
stats *statsCounters
|
|
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
|
verdictsGen atomic.Int64 // incremented on ruleset update
|
|
|
|
overflowPolicy OverflowPolicy
|
|
resultCh chan workerResult
|
|
}
|
|
|
|
func NewEngine(config Config) (Engine, error) {
|
|
workerCount := config.Workers
|
|
if workerCount <= 0 {
|
|
workerCount = runtime.GOMAXPROCS(0)
|
|
if workerCount <= 0 {
|
|
workerCount = 1
|
|
}
|
|
}
|
|
overflowPolicy := config.OverflowPolicy
|
|
if overflowPolicy == "" {
|
|
overflowPolicy = OverflowPolicyAccept
|
|
}
|
|
selectionMode := config.AnalyzerSelectionMode
|
|
if selectionMode == "" {
|
|
selectionMode = AnalyzerSelectionModeSignature
|
|
}
|
|
|
|
stats := &statsCounters{}
|
|
resultCh := make(chan workerResult, workerCount*256)
|
|
|
|
macResolver := newSourceMACResolver()
|
|
var err error
|
|
workers := make([]*worker, workerCount)
|
|
for i := range workers {
|
|
workers[i], err = newWorker(workerConfig{
|
|
ID: i,
|
|
ChanSize: config.WorkerQueueSize,
|
|
Logger: config.Logger,
|
|
Ruleset: config.Ruleset,
|
|
MACResolver: macResolver,
|
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
|
AnalyzerSelectionMode: selectionMode,
|
|
ResultChan: resultCh,
|
|
Stats: stats,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
e := &engine{
|
|
logger: config.Logger,
|
|
io: config.IO,
|
|
workers: workers,
|
|
stats: stats,
|
|
overflowPolicy: overflowPolicy,
|
|
resultCh: resultCh,
|
|
}
|
|
return e, nil
|
|
}
|
|
|
|
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
|
e.verdictsGen.Add(1)
|
|
e.verdicts = sync.Map{}
|
|
for _, w := range e.workers {
|
|
if err := w.UpdateRuleset(r); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *engine) Run(ctx context.Context) error {
|
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
|
defer ioCancel()
|
|
|
|
for _, w := range e.workers {
|
|
go w.Run(ioCtx)
|
|
}
|
|
go e.drainResults(ioCtx)
|
|
|
|
errChan := make(chan error, 1)
|
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
|
if err != nil {
|
|
errChan <- err
|
|
return false
|
|
}
|
|
return e.dispatch(p)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
return err
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (e *engine) dispatch(p io.Packet) bool {
|
|
streamID := p.StreamID()
|
|
|
|
if v, ok := e.verdicts.Load(streamID); ok {
|
|
entry := v.(verdictEntry)
|
|
if entry.Gen == e.verdictsGen.Load() {
|
|
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
|
return true
|
|
}
|
|
}
|
|
|
|
data := p.Data()
|
|
if !validPacket(data) {
|
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
|
return true
|
|
}
|
|
gen := e.verdictsGen.Load()
|
|
index := streamID % uint32(len(e.workers))
|
|
wp := &workerPacket{
|
|
Packet: p,
|
|
StreamID: streamID,
|
|
Data: data,
|
|
Gen: gen,
|
|
}
|
|
if !e.workers[index].Feed(wp) {
|
|
e.stats.OverflowEvents.Add(1)
|
|
switch e.overflowPolicy {
|
|
case OverflowPolicyDrop:
|
|
e.stats.OverflowDrops.Add(1)
|
|
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
|
|
case OverflowPolicyBackpressure:
|
|
e.stats.OverflowBackpressureEvents.Add(1)
|
|
e.workers[index].FeedBlocking(wp)
|
|
default:
|
|
e.stats.OverflowAccepts.Add(1)
|
|
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (e *engine) applyWorkerResult(r workerResult) {
|
|
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
|
|
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
|
|
}
|
|
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
|
}
|
|
|
|
// validPacket reports whether data looks like IP traffic. Two shapes are
// accepted:
//   - a raw IP packet whose version nibble is 4 or 6;
//   - an Ethernet frame of at least 14 bytes whose EtherType (bytes 12-13,
//     big-endian) is IPv4 (0x0800) or IPv6 (0x86DD).
//
// Everything else, including empty input, is rejected.
//
// NOTE(review): 802.1Q-tagged frames (EtherType 0x8100) are rejected, and
// dispatch gives rejected packets io.VerdictAcceptStream without inspection.
// Confirm VLAN-tagged traffic cannot reach this path.
func validPacket(data []byte) bool {
	switch {
	case len(data) == 0:
		return false
	case data[0]>>4 == 4, data[0]>>4 == 6:
		return true
	case len(data) < 14:
		return false
	}
	etherType := uint16(data[12])<<8 | uint16(data[13])
	return etherType == 0x0800 || etherType == 0x86DD
}
|
|
|
|
func (e *engine) drainResults(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case r := <-e.resultCh:
|
|
e.applyWorkerResult(r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *engine) Stats() Stats {
|
|
return Stats{
|
|
OverflowEvents: e.stats.OverflowEvents.Load(),
|
|
OverflowAccepts: e.stats.OverflowAccepts.Load(),
|
|
OverflowDrops: e.stats.OverflowDrops.Load(),
|
|
OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
|
|
AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(),
|
|
AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(),
|
|
UDPTupleLookups: e.stats.UDPTupleLookups.Load(),
|
|
UDPTupleHits: e.stats.UDPTupleHits.Load(),
|
|
}
|
|
}
|