From dc16b979e7cc6a0ebaa3c6478433fd5bf4b4d100 Mon Sep 17 00:00:00 2001 From: hayzam Date: Tue, 12 May 2026 13:31:19 +0000 Subject: [PATCH] engine, io: more caching + optimizations --- engine/engine.go | 47 ++++++++++++++++++++++++++++++------------ engine/mac_resolver.go | 8 ++----- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index 8bf7ff8..ccd8dfe 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -4,6 +4,8 @@ import ( "context" "encoding/binary" "runtime" + "sync" + "sync/atomic" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" @@ -14,10 +16,17 @@ import ( var _ Engine = (*engine)(nil) +type verdictEntry struct { + Verdict io.Verdict + Gen int64 +} + type engine struct { - logger Logger - io io.PacketIO - workers []*worker + logger Logger + io io.PacketIO + workers []*worker + verdicts sync.Map // streamID(uint32) → verdictEntry + verdictsGen atomic.Int64 // incremented on ruleset update } func NewEngine(config Config) (Engine, error) { @@ -43,14 +52,17 @@ func NewEngine(config Config) (Engine, error) { return nil, err } } - return &engine{ + e := &engine{ logger: config.Logger, io: config.IO, workers: workers, - }, nil + } + return e, nil } func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { + e.verdictsGen.Add(1) + e.verdicts = sync.Map{} for _, w := range e.workers { if err := w.UpdateRuleset(r); err != nil { return err @@ -61,14 +73,12 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { func (e *engine) Run(ctx context.Context) error { ioCtx, ioCancel := context.WithCancel(ctx) - defer ioCancel() // Stop workers & IO + defer ioCancel() - // Start workers for _, w := range e.workers { go w.Run(ioCtx) } - // Register IO callback errChan := make(chan error, 1) err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { if err != nil { @@ -81,7 +91,6 @@ func (e *engine) Run(ctx context.Context) error { return err } - // Block until IO errors or context is cancelled select { case err := <-errChan: return err @@ -90,23 +99,35 @@ func (e *engine) Run(ctx context.Context) error { } } -// dispatch dispatches a packet to a worker. func (e *engine) dispatch(p io.Packet) bool { + streamID := p.StreamID() + + if v, ok := e.verdicts.Load(streamID); ok { + entry := v.(verdictEntry) + if entry.Gen == e.verdictsGen.Load() { + _ = e.io.SetVerdict(p, entry.Verdict, nil) + return true + } + } + data := p.Data() layerType, srcMAC, dstMAC, ok := classifyPacket(data) if !ok { - // Unsupported network layer _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil) return true } - index := p.StreamID() % uint32(len(e.workers)) + gen := e.verdictsGen.Load() + index := streamID % uint32(len(e.workers)) e.workers[index].Feed(&workerPacket{ - StreamID: p.StreamID(), + StreamID: streamID, Data: data, LayerType: layerType, SrcMAC: srcMAC, DstMAC: dstMAC, SetVerdict: func(v io.Verdict, b []byte) error { + if v == io.VerdictAcceptStream || v == io.VerdictDropStream { + e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen}) + } return e.io.SetVerdict(p, v, b) }, }) diff --git a/engine/mac_resolver.go b/engine/mac_resolver.go index a1eecd8..91fefd7 100644 --- a/engine/mac_resolver.go +++ b/engine/mac_resolver.go @@ -343,17 +343,13 @@ func lookupNeighborMACNetlink(target net.IP) (net.HardwareAddr, bool) { func readIPv6NeighCommand() map[string]net.HardwareAddr { commands := [][]string{ {"ip", "-6", "neigh", "show"}, - {"/sbin/ip", "-6", "neigh", "show"}, - {"/usr/sbin/ip", "-6", "neigh", "show"}, - {"busybox", "ip", "-6", "neigh", "show"}, - {"/bin/busybox", "ip", "-6", "neigh", "show"}, } + m := make(map[string]net.HardwareAddr) for _, cmd := range commands { out, err := exec.Command(cmd[0], cmd[1:]...).Output() if err != nil || len(out) == 0 { continue } - m := make(map[string]net.HardwareAddr) for _, line := range strings.Split(string(out), "\n") { ip, mac, ok := parseNeighborLine(line) if !ok { @@ -365,7 +361,7 @@ func readIPv6NeighCommand() map[string]net.HardwareAddr { return m } } - return map[string]net.HardwareAddr{} + return m } func parseNeighborLine(line string) (string, net.HardwareAddr, bool) {