Files
Mellaris/engine/engine.go
T

166 lines
3.8 KiB
Go

package engine
import (
"context"
"encoding/binary"
"runtime"
"sync"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Compile-time assertion that *engine implements Engine.
var _ Engine = (*engine)(nil)

type (
	// verdictEntry is a cached stream-level verdict together with the
	// verdict generation (engine.verdictsGen) that was current when the
	// packet producing it was dispatched. Entries whose Gen differs from the
	// current generation are ignored by dispatch.
	verdictEntry struct {
		Verdict io.Verdict
		Gen     int64
	}

	// engine is the Engine implementation: it receives packets from a
	// PacketIO, answers from the verdict cache when possible, and otherwise
	// shards packets across workers by stream ID.
	engine struct {
		logger  Logger
		io      io.PacketIO
		workers []*worker

		// verdicts caches accept/drop-stream verdicts keyed by stream ID
		// (uint32) with verdictEntry values.
		verdicts sync.Map
		// verdictsGen is bumped by UpdateRuleset; cached entries tagged with
		// an older generation are treated as cache misses.
		verdictsGen atomic.Int64
	}
)
// NewEngine builds an engine with config.Workers workers, or
// runtime.NumCPU() workers when config.Workers is zero or negative. All
// workers share a single source-MAC resolver. The first error returned while
// constructing a worker aborts construction and is returned unchanged.
func NewEngine(config Config) (Engine, error) {
	n := config.Workers
	if n <= 0 {
		n = runtime.NumCPU()
	}

	resolver := newSourceMACResolver()
	ws := make([]*worker, 0, n)
	for id := 0; id < n; id++ {
		w, err := newWorker(workerConfig{
			ID:                         id,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                resolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
		})
		if err != nil {
			return nil, err
		}
		ws = append(ws, w)
	}

	return &engine{
		logger:  config.Logger,
		io:      config.IO,
		workers: ws,
	}, nil
}
// UpdateRuleset installs r on every worker and invalidates the cached stream
// verdicts so that subsequent packets are re-evaluated.
//
// The verdict generation is bumped twice. The first bump happens before the
// workers are updated, so dispatch immediately stops serving verdicts cached
// under the old ruleset. The second bump happens afterwards, so verdicts from
// packets dispatched while the update was in progress are not trusted either.
// Depending on when each worker picks up r, those packets may have been
// judged by the old ruleset.
//
// The cache is emptied in place with Range/Delete rather than by assigning a
// fresh sync.Map. dispatch and the per-packet SetVerdict callbacks access
// e.verdicts concurrently, so overwriting the map value underneath them is a
// data race.
//
// If a worker rejects the ruleset, its error is returned and the remaining
// workers are not updated. The cache is still invalidated, because earlier
// workers may already be running r.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	e.verdictsGen.Add(1)

	var updateErr error
	for _, w := range e.workers {
		if err := w.UpdateRuleset(r); err != nil {
			updateErr = err
			break
		}
	}

	e.verdictsGen.Add(1)
	// Deleting during Range is permitted by sync.Map; entries stored
	// concurrently carry a stale generation and are ignored by dispatch.
	e.verdicts.Range(func(key, _ any) bool {
		e.verdicts.Delete(key)
		return true
	})
	return updateErr
}
// Run starts every worker, registers the packet callback with the PacketIO,
// and then blocks. It returns nil when ctx is cancelled, or the first error
// the PacketIO reports through the callback. An error from Register itself
// is returned immediately. In every case the context given to the workers
// and to Register is cancelled when Run returns.
//
// NOTE(review): worker goroutines are not joined on return; this assumes
// worker.Run exits once ioCtx is cancelled.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel()

	for _, w := range e.workers {
		go w.Run(ioCtx)
	}

	// Buffered so the first error can be delivered even before Run reaches
	// the select below.
	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			// Non-blocking send: only the first error is reported. A blocking
			// send would wedge the PacketIO's calling goroutine forever once
			// the one-slot buffer is full and nothing drains it (repeated
			// errors, or errors delivered after Run has returned).
			select {
			case errChan <- err:
			default:
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}

	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}
// dispatch routes one packet received from the PacketIO.
//
// If the packet's stream has a cached accept/drop-stream verdict from the
// current generation, that verdict is applied directly. Packets whose
// framing classifyPacket does not recognise get VerdictAcceptStream.
// Everything else is fed to the worker at index streamID modulo the worker
// count, so all packets of a stream reach the same worker. It always returns
// true.
//
// Errors from SetVerdict on the two direct paths are discarded (best
// effort).
//
// NOTE(review): cached verdicts are only removed by UpdateRuleset. The cache
// therefore grows with the number of distinct stream IDs seen, and a reused
// stream ID inherits its predecessor's verdict. Confirm this is acceptable
// for the PacketIO in use.
func (e *engine) dispatch(p io.Packet) bool {
	sid := p.StreamID()
	if cached, hit := e.verdicts.Load(sid); hit {
		if ve := cached.(verdictEntry); ve.Gen == e.verdictsGen.Load() {
			_ = e.io.SetVerdict(p, ve.Verdict, nil)
			return true
		}
	}

	payload := p.Data()
	lt, src, dst, known := classifyPacket(payload)
	if !known {
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}

	dispatchGen := e.verdictsGen.Load()
	target := e.workers[sid%uint32(len(e.workers))]

	setVerdict := func(v io.Verdict, b []byte) error {
		// Only stream-level verdicts are cached; any other verdict applies
		// to this packet alone.
		switch v {
		case io.VerdictAcceptStream, io.VerdictDropStream:
			e.verdicts.Store(sid, verdictEntry{Verdict: v, Gen: dispatchGen})
		}
		return e.io.SetVerdict(p, v, b)
	}

	target.Feed(&workerPacket{
		StreamID:   sid,
		Data:       payload,
		LayerType:  lt,
		SrcMAC:     src,
		DstMAC:     dst,
		SetVerdict: setVerdict,
	})
	return true
}
// classifyPacket inspects the framing of data. It returns the gopacket layer
// type to start decoding from, plus best-effort source and destination MAC
// addresses when data is an Ethernet frame. The final result is false when
// the framing is not recognised, including for empty input.
//
// Raw IPv4/IPv6 packets (e.g. NFQUEUE payloads) are recognised by the IP
// version nibble of the first byte and carry no MAC addresses. Otherwise,
// data of at least 14 bytes is treated as an Ethernet frame when its
// EtherType (bytes 12-13, big-endian) is IPv4 or IPv6. The returned MACs are
// fresh copies and do not alias data.
//
// NOTE(review): the IP check runs first. An Ethernet frame whose first byte
// (the destination MAC's first octet) has high nibble 4 or 6 is therefore
// classified as a raw IP packet. Confirm no PacketIO delivers such frames,
// or add a disambiguation step.
func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
	if len(data) == 0 {
		return 0, nil, nil, false
	}

	// Fast path: NFQUEUE payloads are usually bare IP packets.
	switch data[0] >> 4 {
	case 4:
		return layers.LayerTypeIPv4, nil, nil, true
	case 6:
		return layers.LayerTypeIPv6, nil, nil, true
	}

	// Ethernet frames, for custom PacketIO implementations.
	const ethHeaderLen = 14
	if len(data) < ethHeaderLen {
		return 0, nil, nil, false
	}
	switch layers.EthernetType(binary.BigEndian.Uint16(data[12:14])) {
	case layers.EthernetTypeIPv4, layers.EthernetTypeIPv6:
		dst := append([]byte(nil), data[0:6]...)
		src := append([]byte(nil), data[6:12]...)
		return layers.LayerTypeEthernet, src, dst, true
	}
	return 0, nil, nil, false
}