engine, io: more caching + optimizations
This commit is contained in:
+34
-13
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
@@ -14,10 +16,17 @@ import (
|
|||||||
|
|
||||||
var _ Engine = (*engine)(nil)
|
var _ Engine = (*engine)(nil)
|
||||||
|
|
||||||
|
type verdictEntry struct {
|
||||||
|
Verdict io.Verdict
|
||||||
|
Gen int64
|
||||||
|
}
|
||||||
|
|
||||||
type engine struct {
|
type engine struct {
|
||||||
logger Logger
|
logger Logger
|
||||||
io io.PacketIO
|
io io.PacketIO
|
||||||
workers []*worker
|
workers []*worker
|
||||||
|
verdicts sync.Map // streamID(uint32) → verdictEntry
|
||||||
|
verdictsGen atomic.Int64 // incremented on ruleset update
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEngine(config Config) (Engine, error) {
|
func NewEngine(config Config) (Engine, error) {
|
||||||
@@ -43,14 +52,17 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &engine{
|
e := &engine{
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
io: config.IO,
|
io: config.IO,
|
||||||
workers: workers,
|
workers: workers,
|
||||||
}, nil
|
}
|
||||||
|
return e, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||||
|
e.verdictsGen.Add(1)
|
||||||
|
e.verdicts = sync.Map{}
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
if err := w.UpdateRuleset(r); err != nil {
|
if err := w.UpdateRuleset(r); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -61,14 +73,12 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
|||||||
|
|
||||||
func (e *engine) Run(ctx context.Context) error {
|
func (e *engine) Run(ctx context.Context) error {
|
||||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||||
defer ioCancel() // Stop workers & IO
|
defer ioCancel()
|
||||||
|
|
||||||
// Start workers
|
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register IO callback
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,7 +91,6 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Block until IO errors or context is cancelled
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
return err
|
return err
|
||||||
@@ -90,23 +99,35 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// dispatch dispatches a packet to a worker.
|
|
||||||
func (e *engine) dispatch(p io.Packet) bool {
|
func (e *engine) dispatch(p io.Packet) bool {
|
||||||
|
streamID := p.StreamID()
|
||||||
|
|
||||||
|
if v, ok := e.verdicts.Load(streamID); ok {
|
||||||
|
entry := v.(verdictEntry)
|
||||||
|
if entry.Gen == e.verdictsGen.Load() {
|
||||||
|
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data := p.Data()
|
data := p.Data()
|
||||||
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
|
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
|
||||||
if !ok {
|
if !ok {
|
||||||
// Unsupported network layer
|
|
||||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
index := p.StreamID() % uint32(len(e.workers))
|
gen := e.verdictsGen.Load()
|
||||||
|
index := streamID % uint32(len(e.workers))
|
||||||
e.workers[index].Feed(&workerPacket{
|
e.workers[index].Feed(&workerPacket{
|
||||||
StreamID: p.StreamID(),
|
StreamID: streamID,
|
||||||
Data: data,
|
Data: data,
|
||||||
LayerType: layerType,
|
LayerType: layerType,
|
||||||
SrcMAC: srcMAC,
|
SrcMAC: srcMAC,
|
||||||
DstMAC: dstMAC,
|
DstMAC: dstMAC,
|
||||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||||
|
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
||||||
|
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
||||||
|
}
|
||||||
return e.io.SetVerdict(p, v, b)
|
return e.io.SetVerdict(p, v, b)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -343,17 +343,13 @@ func lookupNeighborMACNetlink(target net.IP) (net.HardwareAddr, bool) {
|
|||||||
func readIPv6NeighCommand() map[string]net.HardwareAddr {
|
func readIPv6NeighCommand() map[string]net.HardwareAddr {
|
||||||
commands := [][]string{
|
commands := [][]string{
|
||||||
{"ip", "-6", "neigh", "show"},
|
{"ip", "-6", "neigh", "show"},
|
||||||
{"/sbin/ip", "-6", "neigh", "show"},
|
|
||||||
{"/usr/sbin/ip", "-6", "neigh", "show"},
|
|
||||||
{"busybox", "ip", "-6", "neigh", "show"},
|
|
||||||
{"/bin/busybox", "ip", "-6", "neigh", "show"},
|
|
||||||
}
|
}
|
||||||
|
m := make(map[string]net.HardwareAddr)
|
||||||
for _, cmd := range commands {
|
for _, cmd := range commands {
|
||||||
out, err := exec.Command(cmd[0], cmd[1:]...).Output()
|
out, err := exec.Command(cmd[0], cmd[1:]...).Output()
|
||||||
if err != nil || len(out) == 0 {
|
if err != nil || len(out) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m := make(map[string]net.HardwareAddr)
|
|
||||||
for _, line := range strings.Split(string(out), "\n") {
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
ip, mac, ok := parseNeighborLine(line)
|
ip, mac, ok := parseNeighborLine(line)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -365,7 +361,7 @@ func readIPv6NeighCommand() map[string]net.HardwareAddr {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return map[string]net.HardwareAddr{}
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNeighborLine(line string) (string, net.HardwareAddr, bool) {
|
func parseNeighborLine(line string) (string, net.HardwareAddr, bool) {
|
||||||
|
|||||||
Reference in New Issue
Block a user