package engine

import (
	"context"
	"encoding/binary"
	"runtime"

	"git.difuse.io/Difuse/Mellaris/io"
	"git.difuse.io/Difuse/Mellaris/ruleset"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
)

var _ Engine = (*engine)(nil)

// engine is the default Engine implementation. It receives packets from a
// PacketIO and hands them to a fixed pool of workers (see dispatch).
type engine struct {
	logger  Logger
	io      io.PacketIO
	workers []*worker
}

// NewEngine builds an engine from config.
//
// config.Workers workers are created; a value <= 0 means one worker per CPU
// (runtime.NumCPU). All workers share a single source-MAC resolver. The first
// error returned while constructing a worker aborts construction and is
// returned unchanged.
func NewEngine(config Config) (Engine, error) {
	workerCount := config.Workers
	if workerCount <= 0 {
		workerCount = runtime.NumCPU()
	}
	// One resolver instance is shared by every worker.
	macResolver := newSourceMACResolver()
	var err error
	workers := make([]*worker, workerCount)
	for i := range workers {
		workers[i], err = newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                macResolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
		})
		if err != nil {
			return nil, err
		}
	}
	return &engine{
		logger:  config.Logger,
		io:      config.IO,
		workers: workers,
	}, nil
}

// UpdateRuleset pushes r to every worker in order and stops at the first
// error, which is returned unchanged. Workers updated before the failing one
// keep the new ruleset.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	for _, w := range e.workers {
		if err := w.UpdateRuleset(r); err != nil {
			return err
		}
	}
	return nil
}

// Run starts every worker and registers the packet callback with the IO, then
// blocks until either the IO reports an error (which is returned) or ctx is
// cancelled (nil is returned). An error from Register itself is returned
// immediately. When Run returns, the context handed to the workers and to
// Register is cancelled.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel() // Stop workers & IO

	// Start workers
	for _, w := range e.workers {
		go w.Run(ioCtx)
	}

	// Register IO callback.
	// Run consumes at most one value from errChan, and nothing reads it
	// after Run returns. The send must therefore never block: a blocking
	// send of a second (or later) error would hang whichever goroutine the
	// IO uses to invoke this callback.
	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			select {
			case errChan <- err:
			default:
				// An earlier error is already pending or was already
				// handled; drop this one.
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}

	// Block until IO errors or context is cancelled
	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}

// dispatch dispatches a packet to a worker.
func (e *engine) dispatch(p io.Packet) bool { data := p.Data() layerType, srcMAC, dstMAC, ok := classifyPacket(data) if !ok { // Unsupported network layer _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil) return true } // Load balance by stream ID index := p.StreamID() % uint32(len(e.workers)) packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) e.workers[index].Feed(&workerPacket{ StreamID: p.StreamID(), Packet: packet, SrcMAC: srcMAC, DstMAC: dstMAC, SetVerdict: func(v io.Verdict, b []byte) error { return e.io.SetVerdict(p, v, b) }, }) return true } // classifyPacket detects packet framing and returns a gopacket decode layer // plus best-effort source/destination MAC addresses when available. func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) { if len(data) == 0 { return 0, nil, nil, false } // Fast path for IP packets (NFQUEUE payloads are typically IP-only). ipVersion := data[0] >> 4 if ipVersion == 4 { return layers.LayerTypeIPv4, nil, nil, true } if ipVersion == 6 { return layers.LayerTypeIPv6, nil, nil, true } // Ethernet frame path (for custom PacketIO implementations). if len(data) >= 14 { etherType := binary.BigEndian.Uint16(data[12:14]) if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) { return layers.LayerTypeEthernet, append([]byte(nil), data[6:12]...), append([]byte(nil), data[:6]...), true } } return 0, nil, nil, false }