engine, io: more caching + optimizations
This commit is contained in:
+34
-13
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
@@ -14,10 +16,17 @@ import (
|
||||
|
||||
var _ Engine = (*engine)(nil)
|
||||
|
||||
type verdictEntry struct {
|
||||
Verdict io.Verdict
|
||||
Gen int64
|
||||
}
|
||||
|
||||
type engine struct {
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
verdicts sync.Map // streamID(uint32) → verdictEntry
|
||||
verdictsGen atomic.Int64 // incremented on ruleset update
|
||||
}
|
||||
|
||||
func NewEngine(config Config) (Engine, error) {
|
||||
@@ -43,14 +52,17 @@ func NewEngine(config Config) (Engine, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &engine{
|
||||
e := &engine{
|
||||
logger: config.Logger,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
}, nil
|
||||
}
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
e.verdictsGen.Add(1)
|
||||
e.verdicts = sync.Map{}
|
||||
for _, w := range e.workers {
|
||||
if err := w.UpdateRuleset(r); err != nil {
|
||||
return err
|
||||
@@ -61,14 +73,12 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
|
||||
func (e *engine) Run(ctx context.Context) error {
|
||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||
defer ioCancel() // Stop workers & IO
|
||||
defer ioCancel()
|
||||
|
||||
// Start workers
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
|
||||
// Register IO callback
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
@@ -81,7 +91,6 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Block until IO errors or context is cancelled
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
@@ -90,23 +99,35 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch dispatches a packet to a worker.
|
||||
func (e *engine) dispatch(p io.Packet) bool {
|
||||
streamID := p.StreamID()
|
||||
|
||||
if v, ok := e.verdicts.Load(streamID); ok {
|
||||
entry := v.(verdictEntry)
|
||||
if entry.Gen == e.verdictsGen.Load() {
|
||||
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
data := p.Data()
|
||||
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
|
||||
if !ok {
|
||||
// Unsupported network layer
|
||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
return true
|
||||
}
|
||||
index := p.StreamID() % uint32(len(e.workers))
|
||||
gen := e.verdictsGen.Load()
|
||||
index := streamID % uint32(len(e.workers))
|
||||
e.workers[index].Feed(&workerPacket{
|
||||
StreamID: p.StreamID(),
|
||||
StreamID: streamID,
|
||||
Data: data,
|
||||
LayerType: layerType,
|
||||
SrcMAC: srcMAC,
|
||||
DstMAC: dstMAC,
|
||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
||||
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
||||
}
|
||||
return e.io.SetVerdict(p, v, b)
|
||||
},
|
||||
})
|
||||
|
||||
@@ -343,17 +343,13 @@ func lookupNeighborMACNetlink(target net.IP) (net.HardwareAddr, bool) {
|
||||
func readIPv6NeighCommand() map[string]net.HardwareAddr {
|
||||
commands := [][]string{
|
||||
{"ip", "-6", "neigh", "show"},
|
||||
{"/sbin/ip", "-6", "neigh", "show"},
|
||||
{"/usr/sbin/ip", "-6", "neigh", "show"},
|
||||
{"busybox", "ip", "-6", "neigh", "show"},
|
||||
{"/bin/busybox", "ip", "-6", "neigh", "show"},
|
||||
}
|
||||
m := make(map[string]net.HardwareAddr)
|
||||
for _, cmd := range commands {
|
||||
out, err := exec.Command(cmd[0], cmd[1:]...).Output()
|
||||
if err != nil || len(out) == 0 {
|
||||
continue
|
||||
}
|
||||
m := make(map[string]net.HardwareAddr)
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
ip, mac, ok := parseNeighborLine(line)
|
||||
if !ok {
|
||||
@@ -365,7 +361,7 @@ func readIPv6NeighCommand() map[string]net.HardwareAddr {
|
||||
return m
|
||||
}
|
||||
}
|
||||
return map[string]net.HardwareAddr{}
|
||||
return m
|
||||
}
|
||||
|
||||
func parseNeighborLine(line string) (string, net.HardwareAddr, bool) {
|
||||
|
||||
Reference in New Issue
Block a user