engine, io: more caching + optimizations

This commit is contained in:
2026-05-12 13:31:19 +00:00
parent 4d70520e43
commit dc16b979e7
2 changed files with 36 additions and 19 deletions
+34 -13
View File
@@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"runtime" "runtime"
"sync"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -14,10 +16,17 @@ import (
var _ Engine = (*engine)(nil) var _ Engine = (*engine)(nil)
type verdictEntry struct {
Verdict io.Verdict
Gen int64
}
type engine struct { type engine struct {
logger Logger logger Logger
io io.PacketIO io io.PacketIO
workers []*worker workers []*worker
verdicts sync.Map // streamID(uint32) → verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update
} }
func NewEngine(config Config) (Engine, error) { func NewEngine(config Config) (Engine, error) {
@@ -43,14 +52,17 @@ func NewEngine(config Config) (Engine, error) {
return nil, err return nil, err
} }
} }
return &engine{ e := &engine{
logger: config.Logger, logger: config.Logger,
io: config.IO, io: config.IO,
workers: workers, workers: workers,
}, nil }
return e, nil
} }
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
e.verdictsGen.Add(1)
e.verdicts = sync.Map{}
for _, w := range e.workers { for _, w := range e.workers {
if err := w.UpdateRuleset(r); err != nil { if err := w.UpdateRuleset(r); err != nil {
return err return err
@@ -61,14 +73,12 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error { func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx) ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IO defer ioCancel()
// Start workers
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
// Register IO callback
errChan := make(chan error, 1) errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil { if err != nil {
@@ -81,7 +91,6 @@ func (e *engine) Run(ctx context.Context) error {
return err return err
} }
// Block until IO errors or context is cancelled
select { select {
case err := <-errChan: case err := <-errChan:
return err return err
@@ -90,23 +99,35 @@ func (e *engine) Run(ctx context.Context) error {
} }
} }
// dispatch dispatches a packet to a worker.
func (e *engine) dispatch(p io.Packet) bool { func (e *engine) dispatch(p io.Packet) bool {
streamID := p.StreamID()
if v, ok := e.verdicts.Load(streamID); ok {
entry := v.(verdictEntry)
if entry.Gen == e.verdictsGen.Load() {
_ = e.io.SetVerdict(p, entry.Verdict, nil)
return true
}
}
data := p.Data() data := p.Data()
layerType, srcMAC, dstMAC, ok := classifyPacket(data) layerType, srcMAC, dstMAC, ok := classifyPacket(data)
if !ok { if !ok {
// Unsupported network layer
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil) _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true return true
} }
index := p.StreamID() % uint32(len(e.workers)) gen := e.verdictsGen.Load()
index := streamID % uint32(len(e.workers))
e.workers[index].Feed(&workerPacket{ e.workers[index].Feed(&workerPacket{
StreamID: p.StreamID(), StreamID: streamID,
Data: data, Data: data,
LayerType: layerType, LayerType: layerType,
SrcMAC: srcMAC, SrcMAC: srcMAC,
DstMAC: dstMAC, DstMAC: dstMAC,
SetVerdict: func(v io.Verdict, b []byte) error { SetVerdict: func(v io.Verdict, b []byte) error {
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
}
return e.io.SetVerdict(p, v, b) return e.io.SetVerdict(p, v, b)
}, },
}) })
+2 -6
View File
@@ -343,17 +343,13 @@ func lookupNeighborMACNetlink(target net.IP) (net.HardwareAddr, bool) {
func readIPv6NeighCommand() map[string]net.HardwareAddr { func readIPv6NeighCommand() map[string]net.HardwareAddr {
commands := [][]string{ commands := [][]string{
{"ip", "-6", "neigh", "show"}, {"ip", "-6", "neigh", "show"},
{"/sbin/ip", "-6", "neigh", "show"},
{"/usr/sbin/ip", "-6", "neigh", "show"},
{"busybox", "ip", "-6", "neigh", "show"},
{"/bin/busybox", "ip", "-6", "neigh", "show"},
} }
m := make(map[string]net.HardwareAddr)
for _, cmd := range commands { for _, cmd := range commands {
out, err := exec.Command(cmd[0], cmd[1:]...).Output() out, err := exec.Command(cmd[0], cmd[1:]...).Output()
if err != nil || len(out) == 0 { if err != nil || len(out) == 0 {
continue continue
} }
m := make(map[string]net.HardwareAddr)
for _, line := range strings.Split(string(out), "\n") { for _, line := range strings.Split(string(out), "\n") {
ip, mac, ok := parseNeighborLine(line) ip, mac, ok := parseNeighborLine(line)
if !ok { if !ok {
@@ -365,7 +361,7 @@ func readIPv6NeighCommand() map[string]net.HardwareAddr {
return m return m
} }
} }
return map[string]net.HardwareAddr{} return m
} }
func parseNeighborLine(line string) (string, net.HardwareAddr, bool) { func parseNeighborLine(line string) (string, net.HardwareAddr, bool) {