Files
Mellaris/engine/engine.go
T
hayzam 7a3f6e945d Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
2026-05-13 06:10:38 +05:30

212 lines
4.9 KiB
Go

package engine
import (
"context"
"runtime"
"sync"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Compile-time assertion that *engine implements the Engine interface.
var _ Engine = (*engine)(nil)

// verdictEntry is a cached per-stream verdict tagged with the ruleset
// generation under which it was produced; dispatch ignores entries whose
// generation no longer matches verdictsGen.
type verdictEntry struct {
	Verdict io.Verdict // stream-level verdict to reapply to later packets of the stream
	Gen     int64      // engine.verdictsGen value at the time the verdict was recorded
}
// engine is the default Engine implementation. It fans incoming packets out
// to a fixed pool of workers (sharded by stream ID) and caches stream-level
// verdicts so subsequent packets of an already-decided stream skip analysis.
type engine struct {
	logger         Logger
	io             io.PacketIO
	workers        []*worker
	stats          *statsCounters // shared atomic counters, also written by workers
	verdicts       sync.Map       // streamID(uint32) -> verdictEntry
	verdictsGen    atomic.Int64   // incremented on ruleset update; invalidates cached verdicts
	overflowPolicy OverflowPolicy // what to do when a worker queue is full
	resultCh       chan workerResult
}
// NewEngine builds an Engine from config. It resolves defaulted settings
// (worker count, overflow policy, analyzer selection mode), allocates the
// shared stats counters and result channel, and constructs one worker per
// slot. Returns the first worker-construction error, if any.
func NewEngine(config Config) (Engine, error) {
	numWorkers := config.Workers
	if numWorkers <= 0 {
		// Default to one worker per scheduler thread, never fewer than one.
		if numWorkers = runtime.GOMAXPROCS(0); numWorkers <= 0 {
			numWorkers = 1
		}
	}

	policy := config.OverflowPolicy
	if policy == "" {
		policy = OverflowPolicyAccept
	}
	mode := config.AnalyzerSelectionMode
	if mode == "" {
		mode = AnalyzerSelectionModeSignature
	}

	counters := &statsCounters{}
	results := make(chan workerResult, numWorkers*256)
	resolver := newSourceMACResolver()

	pool := make([]*worker, numWorkers)
	for i := range pool {
		w, err := newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                resolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
			AnalyzerSelectionMode:      mode,
			ResultChan:                 results,
			Stats:                      counters,
		})
		if err != nil {
			return nil, err
		}
		pool[i] = w
	}

	return &engine{
		logger:         config.Logger,
		io:             config.IO,
		workers:        pool,
		stats:          counters,
		overflowPolicy: policy,
		resultCh:       results,
	}, nil
}
// UpdateRuleset swaps the active ruleset on every worker and invalidates the
// per-stream verdict cache.
//
// Correctness comes from bumping verdictsGen first: dispatch ignores any
// cached verdictEntry whose generation is older, so emptying the cache is
// purely a memory-reclamation step. The cache is purged entry-by-entry via
// Range/Delete — the previous code reassigned e.verdicts = sync.Map{}, which
// copies a lock-containing value (flagged by go vet) and races with
// concurrent Load/Store calls from dispatch and applyWorkerResult.
//
// Returns the first worker update error; workers updated before the failure
// keep the new ruleset.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	e.verdictsGen.Add(1)
	e.verdicts.Range(func(key, _ any) bool {
		e.verdicts.Delete(key)
		return true
	})
	for _, w := range e.workers {
		if err := w.UpdateRuleset(r); err != nil {
			return err
		}
	}
	return nil
}
// Run starts the workers and the result drainer, registers the packet
// callback with the IO layer, and blocks until the parent context is
// cancelled or the IO layer reports an error. Worker goroutines and the
// drainer are stopped via ioCtx when Run returns.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel() // tear down workers and the drainer on any exit path

	for _, w := range e.workers {
		go w.Run(ioCtx)
	}
	go e.drainResults(ioCtx)

	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			// Non-blocking send: only the first error is consumed. The old
			// unconditional send could block this IO callback forever if a
			// second error arrived after Run had already returned.
			select {
			case errChan <- err:
			default:
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}

	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		// Graceful shutdown requested by the caller; not an error.
		return nil
	}
}
// dispatch routes one packet: serve it from the verdict cache when possible,
// reject unparseable data, and otherwise hand it to the worker that owns its
// stream. Always returns true so the IO callback keeps delivering packets.
func (e *engine) dispatch(p io.Packet) bool {
	streamID := p.StreamID()
	if v, ok := e.verdicts.Load(streamID); ok {
		entry := v.(verdictEntry)
		if entry.Gen == e.verdictsGen.Load() {
			// Cached stream verdict from the current ruleset generation:
			// apply it directly, no re-analysis needed.
			_ = e.io.SetVerdict(p, entry.Verdict, nil)
			return true
		}
		// Fix: evict entries from older ruleset generations. They can never
		// match again, and the old code left them in the map indefinitely,
		// growing it without bound across ruleset updates.
		e.verdicts.Delete(streamID)
	}

	data := p.Data()
	if !validPacket(data) {
		// Not recognizably IP; accept the whole stream without analysis.
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}

	gen := e.verdictsGen.Load()
	// Stable stream->worker mapping preserves per-stream packet ordering.
	index := streamID % uint32(len(e.workers))
	wp := &workerPacket{
		Packet:   p,
		StreamID: streamID,
		Data:     data,
		Gen:      gen,
	}
	if !e.workers[index].Feed(wp) {
		// Worker queue is full: resolve per the configured overflow policy.
		e.stats.OverflowEvents.Add(1)
		switch e.overflowPolicy {
		case OverflowPolicyDrop:
			e.stats.OverflowDrops.Add(1)
			_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
		case OverflowPolicyBackpressure:
			e.stats.OverflowBackpressureEvents.Add(1)
			e.workers[index].FeedBlocking(wp)
		default:
			// OverflowPolicyAccept (and any unknown policy): let it through.
			e.stats.OverflowAccepts.Add(1)
			_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
		}
	}
	return true
}
// applyWorkerResult caches stream-final verdicts and forwards the verdict
// (with any packet modification) to the IO layer.
func (e *engine) applyWorkerResult(r workerResult) {
	switch r.Verdict {
	case io.VerdictAcceptStream, io.VerdictDropStream:
		// Final for the whole stream: cache it, tagged with the ruleset
		// generation it was computed under, so dispatch can short-circuit.
		e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
	}
	_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}
// validPacket reports whether data plausibly holds an IP packet, either raw
// (IPv4/IPv6 judged by the version nibble of the first byte) or inside an
// Ethernet frame carrying the IPv4 (0x0800) or IPv6 (0x86DD) EtherType.
//
// NOTE(review): the version-nibble test runs first, so an Ethernet frame
// whose destination MAC begins with 0x4_ or 0x6_ is classified as raw IP —
// presumably the IO layer delivers a single known framing; confirm.
func validPacket(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	switch data[0] >> 4 {
	case 4, 6:
		return true
	}
	if len(data) < 14 {
		return false
	}
	switch uint16(data[12])<<8 | uint16(data[13]) {
	case 0x0800, 0x86DD:
		return true
	}
	return false
}
// drainResults consumes worker results from resultCh and applies each one,
// returning when ctx is cancelled.
func (e *engine) drainResults(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case res := <-e.resultCh:
			e.applyWorkerResult(res)
		}
	}
}
// Stats returns a point-in-time snapshot of the engine's atomic counters.
// Each field is loaded independently, so the snapshot is not globally
// consistent across counters — fine for monitoring purposes.
func (e *engine) Stats() Stats {
	c := e.stats
	var s Stats
	s.OverflowEvents = c.OverflowEvents.Load()
	s.OverflowAccepts = c.OverflowAccepts.Load()
	s.OverflowDrops = c.OverflowDrops.Load()
	s.OverflowBackpressureEvents = c.OverflowBackpressureEvents.Load()
	s.AnalyzerSelectionsTotal = c.AnalyzerSelectionsTotal.Load()
	s.AnalyzerSelectionsPruned = c.AnalyzerSelectionsPruned.Load()
	s.UDPTupleLookups = c.UDPTupleLookups.Load()
	s.UDPTupleHits = c.UDPTupleHits.Load()
	return s
}