package engine

import (
	"context"
	"fmt"
	"runtime"
	"sync"
	"sync/atomic"

	"git.difuse.io/Difuse/Mellaris/io"
	"git.difuse.io/Difuse/Mellaris/ruleset"
)

var _ Engine = (*engine)(nil)

// verdictEntry is a cached stream-level verdict tagged with the ruleset
// generation (engine.verdictsGen) under which its packet was dispatched.
// An entry whose Gen differs from the current generation is treated as a miss.
type verdictEntry struct {
	Verdict io.Verdict
	Gen     int64
}

// engine receives packets from an io.PacketIO, shards them across a fixed
// set of workers by stream ID, and applies the verdicts the workers send back
// on resultCh.
//
// NOTE(review): verdicts entries are only evicted by UpdateRuleset; confirm
// whether stream teardown should also evict them so the cache stays bounded.
type engine struct {
	logger         Logger
	io             io.PacketIO
	workers        []*worker
	stats          *statsCounters
	verdicts       sync.Map     // streamID(uint32) -> verdictEntry
	verdictsGen    atomic.Int64 // incremented on ruleset update
	overflowPolicy OverflowPolicy
	resultCh       chan workerResult
}

// NewEngine builds an engine from config. Zero values fall back to defaults:
// Workers <= 0 uses runtime.GOMAXPROCS(0) (at least 1), an empty
// OverflowPolicy uses OverflowPolicyAccept, and an empty
// AnalyzerSelectionMode uses AnalyzerSelectionModeSignature. All workers
// share one source-MAC resolver, one set of stats counters and one result
// channel. The first worker-construction error is returned, wrapped with the
// failing worker's index.
func NewEngine(config Config) (Engine, error) {
	workerCount := config.Workers
	if workerCount <= 0 {
		workerCount = runtime.GOMAXPROCS(0)
		if workerCount <= 0 {
			workerCount = 1
		}
	}

	overflowPolicy := config.OverflowPolicy
	if overflowPolicy == "" {
		overflowPolicy = OverflowPolicyAccept
	}

	selectionMode := config.AnalyzerSelectionMode
	if selectionMode == "" {
		selectionMode = AnalyzerSelectionModeSignature
	}

	stats := &statsCounters{}
	// 256 buffered results per worker decouples workers from the drain loop.
	resultCh := make(chan workerResult, workerCount*256)
	macResolver := newSourceMACResolver()

	workers := make([]*worker, workerCount)
	for i := range workers {
		w, err := newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                macResolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
			AnalyzerSelectionMode:      selectionMode,
			ResultChan:                 resultCh,
			Stats:                      stats,
		})
		if err != nil {
			return nil, fmt.Errorf("creating worker %d: %w", i, err)
		}
		workers[i] = w
	}

	e := &engine{
		logger:         config.Logger,
		io:             config.IO,
		workers:        workers,
		stats:          stats,
		overflowPolicy: overflowPolicy,
		resultCh:       resultCh,
	}
	return e, nil
}

// UpdateRuleset installs r on every worker and then invalidates all cached
// stream verdicts by bumping the generation and clearing the cache.
//
// Workers are updated before the generation is bumped. Assuming
// worker.UpdateRuleset takes effect before it returns, any packet tagged with
// the new generation is then evaluated against r. Bumping first would let a
// not-yet-updated worker cache an old-ruleset verdict under the new
// generation. If a worker rejects r, the remaining workers are skipped, but
// the cache is still invalidated because earlier workers may already be
// running r. That error is returned.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	var updateErr error
	for i, w := range e.workers {
		if err := w.UpdateRuleset(r); err != nil {
			updateErr = fmt.Errorf("updating ruleset on worker %d: %w", i, err)
			break
		}
	}

	e.verdictsGen.Add(1)
	// Clear in place. dispatch and the drain goroutine use e.verdicts
	// concurrently, so assigning a fresh sync.Map to the field would race.
	e.verdicts.Range(func(key, _ any) bool {
		e.verdicts.Delete(key)
		return true
	})
	return updateErr
}

// Run starts every worker and the result-drain loop, registers the packet
// callback with the IO backend, and blocks until either:
//   - ctx is cancelled (Run returns nil), or
//   - the backend reports an error through the callback (Run returns it).
//
// On return, the derived context is cancelled to signal the worker and drain
// goroutines to stop. Run does not wait for them to exit.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel()

	for _, w := range e.workers {
		go w.Run(ioCtx)
	}
	go e.drainResults(ioCtx)

	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, cbErr error) bool {
		if cbErr != nil {
			// Keep only the first error. A plain send would block the IO
			// callback forever if the backend reported a second error.
			select {
			case errChan <- cbErr:
			default:
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return fmt.Errorf("registering packet callback: %w", err)
	}

	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}

// dispatch routes one packet from the IO callback:
//   - A cached stream verdict for the current generation is applied directly.
//   - Packets that are neither raw IPv4/IPv6 nor Ethernet-framed IP are
//     accepted for the whole stream.
//   - Everything else is queued to the worker chosen by streamID modulo the
//     worker count. If that queue is full, e.overflowPolicy decides the
//     outcome.
//
// It always returns true so the backend keeps delivering packets.
//
// SetVerdict errors are deliberately ignored: this runs inside the IO
// callback, which has no way to report per-packet failures.
func (e *engine) dispatch(p io.Packet) bool {
	streamID := p.StreamID()
	// Read the generation once so the cache check and the tag attached to a
	// queued packet refer to the same ruleset generation.
	gen := e.verdictsGen.Load()

	if v, ok := e.verdicts.Load(streamID); ok {
		if entry := v.(verdictEntry); entry.Gen == gen {
			_ = e.io.SetVerdict(p, entry.Verdict, nil)
			return true
		}
	}

	data := p.Data()
	if !validPacket(data) {
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}

	// NewEngine guarantees at least one worker, so the modulo is safe.
	w := e.workers[streamID%uint32(len(e.workers))]
	wp := &workerPacket{
		Packet:   p,
		StreamID: streamID,
		Data:     data,
		Gen:      gen,
	}
	if w.Feed(wp) {
		return true
	}

	e.stats.OverflowEvents.Add(1)
	switch e.overflowPolicy {
	case OverflowPolicyDrop:
		e.stats.OverflowDrops.Add(1)
		_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
	case OverflowPolicyBackpressure:
		e.stats.OverflowBackpressureEvents.Add(1)
		// Blocks the IO callback until the worker has room.
		w.FeedBlocking(wp)
	default:
		e.stats.OverflowAccepts.Add(1)
		_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
	}
	return true
}

// applyWorkerResult applies a worker's verdict to its packet. Stream-level
// verdicts (AcceptStream/DropStream) are also cached, but only if the packet
// was dispatched under the current generation. Results from an older
// generation could never produce a cache hit; caching them would only refill
// the map after UpdateRuleset cleared it.
func (e *engine) applyWorkerResult(r workerResult) {
	streamVerdict := r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream
	if streamVerdict && r.Gen == e.verdictsGen.Load() {
		e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
	}
	// Best effort, as in dispatch: there is no caller to report errors to.
	_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}

// validPacket reports whether data looks like an IP packet. It returns true
// when:
//   - the first nibble is an IP version of 4 or 6, or
//   - data is at least 14 bytes long and its EtherType at offsets 12-13 is
//     IPv4 (0x0800) or IPv6 (0x86DD).
//
// It returns false for empty input.
func validPacket(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	ipVersion := data[0] >> 4
	if ipVersion == 4 || ipVersion == 6 {
		return true
	}
	if len(data) >= 14 {
		etherType := uint16(data[12])<<8 | uint16(data[13])
		if etherType == 0x0800 || etherType == 0x86DD {
			return true
		}
	}
	return false
}

// drainResults applies worker results from e.resultCh until ctx is done.
// Results still buffered when ctx is cancelled are not applied.
func (e *engine) drainResults(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case r := <-e.resultCh:
			e.applyWorkerResult(r)
		}
	}
}

// Stats returns a snapshot of the engine's counters. Each counter is loaded
// atomically on its own, so the snapshot as a whole is not taken at a single
// instant.
func (e *engine) Stats() Stats {
	return Stats{
		OverflowEvents:             e.stats.OverflowEvents.Load(),
		OverflowAccepts:            e.stats.OverflowAccepts.Load(),
		OverflowDrops:              e.stats.OverflowDrops.Load(),
		OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
		AnalyzerSelectionsTotal:    e.stats.AnalyzerSelectionsTotal.Load(),
		AnalyzerSelectionsPruned:   e.stats.AnalyzerSelectionsPruned.Load(),
		UDPTupleLookups:            e.stats.UDPTupleLookups.Load(),
		UDPTupleHits:               e.stats.UDPTupleHits.Load(),
	}
}