Files
Mellaris/engine/engine.go
hayzam a879ab4140
Some checks failed
Quality check / Tests (push) Has been cancelled
Quality check / Static analysis (push) Has been cancelled
mac resolution
2026-02-11 12:04:11 +05:30

146 lines
3.5 KiB
Go

package engine
import (
"context"
"encoding/binary"
"runtime"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Compile-time assertion that *engine implements the Engine interface.
var _ Engine = (*engine)(nil)
// engine is the Engine implementation returned by NewEngine. It receives
// packets from a PacketIO and distributes them across a fixed set of workers.
type engine struct {
// logger is taken from Config.Logger by NewEngine.
logger Logger
// io is the packet source; Run registers a callback with it and dispatch
// issues verdicts through it.
io io.PacketIO
// workers is the pool created by NewEngine; dispatch selects one per packet
// from the packet's stream ID.
workers []*worker
}
// NewEngine builds an Engine from config.
//
// It creates config.Workers workers, falling back to runtime.NumCPU() when
// that value is zero or negative. Every worker receives the same logger,
// ruleset, limits and source MAC resolver; only the worker ID differs.
// Construction stops at the first worker that fails to initialize and that
// error is returned unchanged.
func NewEngine(config Config) (Engine, error) {
	n := config.Workers
	if n <= 0 {
		n = runtime.NumCPU()
	}

	// Settings common to all workers; the ID is filled in per worker below.
	base := workerConfig{
		ChanSize:                   config.WorkerQueueSize,
		Logger:                     config.Logger,
		Ruleset:                    config.Ruleset,
		MACResolver:                newSourceMACResolver(),
		TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
		TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
		UDPMaxStreams:              config.WorkerUDPMaxStreams,
	}

	pool := make([]*worker, n)
	for id := 0; id < n; id++ {
		cfg := base
		cfg.ID = id

		w, err := newWorker(cfg)
		if err != nil {
			return nil, err
		}
		pool[id] = w
	}

	eng := &engine{
		logger:  config.Logger,
		io:      config.IO,
		workers: pool,
	}
	return eng, nil
}
// UpdateRuleset passes r to each worker in pool order.
//
// It returns the first error a worker reports and does not contact the
// remaining workers; workers visited before the failure keep the result of
// their own successful UpdateRuleset call. It returns nil when every worker
// accepted r.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	for i := 0; i < len(e.workers); i++ {
		err := e.workers[i].UpdateRuleset(r)
		if err != nil {
			return err
		}
	}
	return nil
}
// Run starts every worker in its own goroutine, registers the packet callback
// with the IO backend, and then blocks until the backend reports an error or
// ctx is cancelled.
//
// It returns the error from Register if registration fails, an error
// delivered to the callback if one arrives, or nil once ctx is done. The
// context derived for the workers and the IO backend is cancelled when Run
// returns.
func (e *engine) Run(ctx context.Context) error {
	runCtx, stop := context.WithCancel(ctx)
	defer stop() // Signals both the workers and the IO backend.

	for i := range e.workers {
		go e.workers[i].Run(runCtx)
	}

	// Capacity 1 lets the callback hand over one error without waiting for the
	// select below to be receiving.
	ioErrs := make(chan error, 1)
	onPacket := func(p io.Packet, err error) bool {
		if err == nil {
			return e.dispatch(p)
		}
		ioErrs <- err
		return false
	}
	if err := e.io.Register(runCtx, onPacket); err != nil {
		return err
	}

	// Wait for whichever comes first: caller cancellation or an IO error.
	select {
	case <-ctx.Done():
		return nil
	case ioErr := <-ioErrs:
		return ioErr
	}
}
// dispatch hands one packet to the worker selected by its stream ID.
//
// Packets whose framing classifyPacket does not recognize are not sent to a
// worker; they receive io.VerdictAcceptStream instead. dispatch always returns
// true.
func (e *engine) dispatch(p io.Packet) bool {
	raw := p.Data()
	decodeAs, src, dst, supported := classifyPacket(raw)
	if !supported {
		// Unsupported network layer: accept the stream. The verdict error is
		// deliberately discarded.
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}

	// Packets of the same stream always map to the same worker.
	target := e.workers[p.StreamID()%uint32(len(e.workers))]

	// The worker reports its decision through this closure, which binds the
	// verdict to the packet being dispatched.
	verdict := func(v io.Verdict, b []byte) error {
		return e.io.SetVerdict(p, v, b)
	}

	decodeOpts := gopacket.DecodeOptions{Lazy: true, NoCopy: true}
	target.Feed(&workerPacket{
		StreamID:   p.StreamID(),
		Packet:     gopacket.NewPacket(raw, decodeAs, decodeOpts),
		SrcMAC:     src,
		DstMAC:     dst,
		SetVerdict: verdict,
	})
	return true
}
// classifyPacket inspects the leading bytes of data to decide how the packet
// is framed.
//
// It returns the gopacket layer type to start decoding from, copies of the
// source and destination MAC addresses (nil for bare IP packets), and whether
// the framing was recognized. Empty input and anything that is neither IP nor
// an Ethernet frame carrying IPv4/IPv6 yields (0, nil, nil, false).
//
// NOTE: the IP version check runs first, so any input whose first byte has a
// high nibble of 4 or 6 is reported as bare IP, including an Ethernet frame
// whose destination MAC happens to start that way.
func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
	// Length of the fixed Ethernet header: two 6-byte MACs plus the EtherType.
	const ethHeaderLen = 14

	if len(data) == 0 {
		return 0, nil, nil, false
	}

	// Fast path for IP packets (NFQUEUE payloads are typically IP-only): the
	// version lives in the upper four bits of the first byte.
	switch data[0] >> 4 {
	case 4:
		return layers.LayerTypeIPv4, nil, nil, true
	case 6:
		return layers.LayerTypeIPv6, nil, nil, true
	}

	// Ethernet frame path (for custom PacketIO implementations).
	if len(data) < ethHeaderLen {
		return 0, nil, nil, false
	}
	switch layers.EthernetType(binary.BigEndian.Uint16(data[12:14])) {
	case layers.EthernetTypeIPv4, layers.EthernetTypeIPv6:
		// Copy the addresses so they do not alias the packet buffer.
		srcMAC := append([]byte(nil), data[6:12]...)
		dstMAC := append([]byte(nil), data[:6]...)
		return layers.LayerTypeEthernet, srcMAC, dstMAC, true
	}
	return 0, nil, nil, false
}