package engine

import (
	"context"
	"sync"
	"sync/atomic"

	"git.difuse.io/Difuse/Mellaris/io"
	"git.difuse.io/Difuse/Mellaris/ruleset"
)

var _ Engine = (*engine)(nil)

// verdictEntry is a cached stream-level verdict. Gen is the value of
// engine.verdictsGen that was current when the packet producing the verdict
// was dispatched; entries whose Gen differs from the current generation are
// ignored by dispatch.
type verdictEntry struct {
	Verdict io.Verdict
	Gen     int64
}

// engine fans packets received from a PacketIO out to a fixed set of workers
// (selected by stream ID) and caches stream-level verdicts so that later
// packets of an already-decided stream bypass the workers.
type engine struct {
	logger  Logger
	io      io.PacketIO
	workers []*worker

	// verdicts maps a stream ID (uint32) to its cached verdictEntry.
	// NOTE(review): entries are only removed on ruleset updates; nothing here
	// evicts streams that have ended, so the cache grows with the number of
	// distinct stream IDs seen — confirm this is acceptable.
	verdicts sync.Map
	// verdictsGen is bumped on ruleset updates so cached verdicts can be
	// invalidated without replacing the map.
	verdictsGen atomic.Int64

	// overflowCh receives packets whose worker queue was full; drainOverflow
	// accepts them without inspection.
	overflowCh chan *workerPacket
}

// NewEngine creates an engine with config.Workers workers (at least one).
// All workers share one source-MAC resolver. The first error returned by
// newWorker is returned unchanged.
func NewEngine(config Config) (Engine, error) {
	workerCount := config.Workers
	if workerCount <= 0 {
		workerCount = 1
	}
	macResolver := newSourceMACResolver()

	workers := make([]*worker, workerCount)
	for i := range workers {
		w, err := newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                macResolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
		})
		if err != nil {
			return nil, err
		}
		workers[i] = w
	}

	return &engine{
		logger:     config.Logger,
		io:         config.IO,
		workers:    workers,
		overflowCh: make(chan *workerPacket, 1024),
	}, nil
}

// UpdateRuleset installs r on every worker and invalidates all cached stream
// verdicts. It stops at, and returns, the first error reported by a worker.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	// Invalidate immediately so packets stop being short-circuited by
	// verdicts decided under the old ruleset.
	e.verdictsGen.Add(1)
	// Bump once more when the workers have been switched over: a packet
	// dispatched after the first bump may have been handled by a worker that
	// still held the old ruleset, and its verdict must not survive.
	defer e.verdictsGen.Add(1)

	// Empty the cache in place to release memory. Assigning a fresh
	// sync.Map here would race with concurrent Load/Store calls from
	// dispatch and the worker verdict callbacks.
	e.verdicts.Range(func(k, _ any) bool {
		e.verdicts.Delete(k)
		return true
	})

	for _, w := range e.workers {
		if err := w.UpdateRuleset(r); err != nil {
			return err
		}
	}
	return nil
}

// Run starts the workers and the overflow drainer, registers the packet
// callback with the PacketIO and blocks until ctx is cancelled (returning
// nil) or the PacketIO reports an error (returning that error). The context
// handed to the workers, the drainer and Register is cancelled when Run
// returns.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel()

	// The drainer shares this Run's lifetime with the workers, so a later
	// call to Run gets a live drainer again.
	go e.drainOverflow(ioCtx)
	for _, w := range e.workers {
		go w.Run(ioCtx)
	}

	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			// Only the first error is reported. A blocking send on the
			// already-full channel would hang the IO callback forever.
			select {
			case errChan <- err:
			default:
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}

	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}

// dispatch handles one packet delivered by the PacketIO and always returns
// true so the IO keeps delivering packets.
//
// A cached stream verdict of the current generation is applied directly.
// Packets rejected by validPacket get VerdictAcceptStream. Everything else is
// fed to the worker selected by stream ID; when that worker's queue is full
// the packet goes to the overflow drainer, and when that queue is full too
// the packet is accepted here. Verdict errors are ignored: the callback has
// no way to report them.
func (e *engine) dispatch(p io.Packet) bool {
	streamID := p.StreamID()
	if v, ok := e.verdicts.Load(streamID); ok {
		entry := v.(verdictEntry)
		if entry.Gen == e.verdictsGen.Load() {
			_ = e.io.SetVerdict(p, entry.Verdict, nil)
			return true
		}
	}

	data := p.Data()
	if !validPacket(data) {
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}

	// Capture the generation at dispatch time so a verdict produced after a
	// ruleset update is cached under the (now stale) generation it was
	// decided for.
	gen := e.verdictsGen.Load()
	wp := &workerPacket{
		StreamID: streamID,
		Data:     data,
		SetVerdict: func(v io.Verdict, b []byte) error {
			if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
				e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
			}
			return e.io.SetVerdict(p, v, b)
		},
	}

	index := streamID % uint32(len(e.workers))
	if e.workers[index].Feed(wp) {
		return true
	}
	select {
	case e.overflowCh <- wp:
	default:
		// Worker and overflow queues are both saturated. Fail open, as
		// drainOverflow does, instead of leaving the packet without a
		// verdict.
		_ = wp.SetVerdict(io.VerdictAccept, nil)
	}
	return true
}

// validPacket reports whether data looks parseable by the workers: either a
// raw IP packet whose version nibble is 4 or 6, or a frame of at least 14
// bytes whose EtherType field (bytes 12-13, big-endian) is IPv4 (0x0800) or
// IPv6 (0x86DD).
func validPacket(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	ipVersion := data[0] >> 4
	if ipVersion == 4 || ipVersion == 6 {
		return true
	}
	if len(data) >= 14 {
		etherType := uint16(data[12])<<8 | uint16(data[13])
		if etherType == 0x0800 || etherType == 0x86DD {
			return true
		}
	}
	return false
}

// drainOverflow accepts, without inspection, every packet sent to overflowCh
// until ctx is done.
func (e *engine) drainOverflow(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case wp := <-e.overflowCh:
			_ = wp.SetVerdict(io.VerdictAccept, nil)
		}
	}
}