First Commit
Some checks failed
Quality check / Static analysis (push) Has been cancelled
Quality check / Tests (push) Has been cancelled

This commit is contained in:
Hayzam Sherif
2026-02-11 06:27:36 +05:30
commit 94e1e26cc3
56 changed files with 8530 additions and 0 deletions

115
engine/engine.go Normal file
View File

@@ -0,0 +1,115 @@
package engine
import (
"context"
"runtime"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Compile-time check that *engine implements Engine.
var _ Engine = (*engine)(nil)

// engine is the default Engine implementation. It fans packets out from a
// single PacketIO source to a fixed pool of workers.
type engine struct {
	logger  Logger
	io      io.PacketIO
	workers []*worker // fixed pool; packets are sharded by stream ID in dispatch
}
// NewEngine creates an Engine from the given Config, spawning one worker
// per requested slot. A non-positive Workers value defaults to the number
// of CPU cores.
func NewEngine(config Config) (Engine, error) {
	count := config.Workers
	if count <= 0 {
		count = runtime.NumCPU()
	}
	workers := make([]*worker, 0, count)
	for i := 0; i < count; i++ {
		w, err := newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
		})
		if err != nil {
			return nil, err
		}
		workers = append(workers, w)
	}
	return &engine{
		logger:  config.Logger,
		io:      config.IO,
		workers: workers,
	}, nil
}
// UpdateRuleset propagates the new ruleset to every worker, stopping at
// the first failure.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	for i := range e.workers {
		if err := e.workers[i].UpdateRuleset(r); err != nil {
			return err
		}
	}
	return nil
}
// Run starts all workers, registers the IO packet callback, and blocks
// until the IO layer reports an error or ctx is cancelled. Cancelling ctx
// (or returning from Run) also signals workers and IO via the derived
// ioCtx.
// NOTE(review): worker goroutines are signalled through ioCtx but never
// awaited, so Run may return while workers are still draining — confirm
// callers don't rely on a fully quiesced engine after Run returns.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel() // Stop workers & IO
	// Start workers
	for _, w := range e.workers {
		go w.Run(ioCtx)
	}
	// Register IO callback
	// errChan is buffered so the callback's send never blocks on the
	// first error.
	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			errChan <- err
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}
	// Block until IO errors or context is cancelled
	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}
// dispatch routes a packet to one of the engine's workers, load-balanced
// by stream ID so all packets of a stream reach the same worker. It
// returns true to keep the IO callback registered.
func (e *engine) dispatch(p io.Packet) bool {
	data := p.Data()
	if len(data) == 0 {
		// Defensive: an empty packet would panic on data[0] below.
		// Accept the stream, same as any other unparseable input.
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}
	// The IP version is the high nibble of the first byte.
	var layerType gopacket.LayerType
	switch data[0] >> 4 {
	case 4:
		layerType = layers.LayerTypeIPv4
	case 6:
		layerType = layers.LayerTypeIPv6
	default:
		// Unsupported network layer
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}
	// Load balance by stream ID
	index := p.StreamID() % uint32(len(e.workers))
	packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
	e.workers[index].Feed(&workerPacket{
		StreamID: p.StreamID(),
		Packet:   packet,
		SetVerdict: func(v io.Verdict, b []byte) error {
			return e.io.SetVerdict(p, v, b)
		},
	})
	return true
}

49
engine/interface.go Normal file
View File

@@ -0,0 +1,49 @@
package engine
import (
"context"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Engine is the main engine for Mellaris.
type Engine interface {
	// UpdateRuleset updates the ruleset. Safe to call while Run is active.
	UpdateRuleset(ruleset.Ruleset) error
	// Run runs the engine, until an error occurs or the context is cancelled.
	Run(context.Context) error
}
// Config is the configuration for the engine.
type Config struct {
	Logger  Logger      // receives all engine/worker/analyzer events
	IO      io.PacketIO // packet source & verdict sink
	Ruleset ruleset.Ruleset
	Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
	// Per-worker tunables; zero or negative values fall back to the
	// defaults in worker.go.
	WorkerQueueSize                  int
	WorkerTCPMaxBufferedPagesTotal   int
	WorkerTCPMaxBufferedPagesPerConn int
	WorkerUDPMaxStreams              int
}
// Logger is the combined logging interface for the engine, workers and analyzers.
type Logger interface {
	// Worker lifecycle events.
	WorkerStart(id int)
	WorkerStop(id int)
	// TCP stream events: creation, property-map changes (close reports
	// whether the update came from closing entries), and verdict actions
	// (noMatch reports a default-accept with no matching rule).
	TCPStreamNew(workerID int, info ruleset.StreamInfo)
	TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
	TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
	// UDP stream events, mirroring the TCP ones above.
	UDPStreamNew(workerID int, info ruleset.StreamInfo)
	UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
	UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
	// ModifyError reports a failed or invalid modifier invocation.
	ModifyError(info ruleset.StreamInfo, err error)
	// Leveled, per-analyzer log sinks keyed by stream ID and analyzer name.
	AnalyzerDebugf(streamID int64, name string, format string, args ...interface{})
	AnalyzerInfof(streamID int64, name string, format string, args ...interface{})
	AnalyzerErrorf(streamID int64, name string, format string, args ...interface{})
}

229
engine/tcp.go Normal file
View File

@@ -0,0 +1,229 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// tcpVerdict is a subset of io.Verdict for TCP streams.
// We don't allow modifying or dropping a single packet
// for TCP streams for now, as it doesn't make much sense.
type tcpVerdict io.Verdict

const (
	tcpVerdictAccept       = tcpVerdict(io.VerdictAccept)
	tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream)
	tcpVerdictDropStream   = tcpVerdict(io.VerdictDropStream)
)

// tcpContext carries per-packet metadata through the reassembler and
// collects the verdict decided while that packet was processed.
type tcpContext struct {
	*gopacket.PacketMetadata
	Verdict tcpVerdict
}

// GetCaptureInfo implements reassembly.AssemblerContext.
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
	return ctx.CaptureInfo
}
// tcpStreamFactory creates tcpStreams for the reassembler's stream pool.
// It holds the worker's current ruleset, guarded by RulesetMutex so
// UpdateRuleset can swap it while streams are being created.
type tcpStreamFactory struct {
	WorkerID     int
	Logger       Logger
	Node         *snowflake.Node // generator for unique stream IDs
	RulesetMutex sync.RWMutex
	Ruleset      ruleset.Ruleset
}
// New implements reassembly.StreamFactory. It assigns the stream a
// snowflake ID, logs its creation, snapshots the current ruleset, and
// instantiates one entry per TCP-capable analyzer the ruleset selects for
// this stream.
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
	sid := f.Node.Generate()
	srcIP := net.IP(ipFlow.Src().Raw())
	dstIP := net.IP(ipFlow.Dst().Raw())
	streamInfo := ruleset.StreamInfo{
		ID:       sid.Int64(),
		Protocol: ruleset.ProtocolTCP,
		SrcIP:    srcIP,
		DstIP:    dstIP,
		SrcPort:  uint16(tcp.SrcPort),
		DstPort:  uint16(tcp.DstPort),
		Props:    make(analyzer.CombinedPropMap),
	}
	f.Logger.TCPStreamNew(f.WorkerID, streamInfo)
	// Snapshot the ruleset under the read lock.
	f.RulesetMutex.RLock()
	rs := f.Ruleset
	f.RulesetMutex.RUnlock()
	tcpAns := analyzersToTCPAnalyzers(rs.Analyzers(streamInfo))
	// Create entries for each analyzer
	entryList := make([]*tcpStreamEntry, 0, len(tcpAns))
	for _, an := range tcpAns {
		tcpInfo := analyzer.TCPInfo{
			SrcIP:   srcIP,
			DstIP:   dstIP,
			SrcPort: uint16(tcp.SrcPort),
			DstPort: uint16(tcp.DstPort),
		}
		alog := &analyzerLogger{
			StreamID: sid.Int64(),
			Name:     an.Name(),
			Logger:   f.Logger,
		}
		entryList = append(entryList, &tcpStreamEntry{
			Name:     an.Name(),
			Stream:   an.NewTCP(tcpInfo, alog),
			HasLimit: an.Limit() > 0,
			Quota:    an.Limit(),
		})
	}
	return &tcpStream{
		info:          streamInfo,
		virgin:        true,
		logger:        f.Logger,
		ruleset:       rs,
		activeEntries: entryList,
	}
}
// UpdateRuleset atomically replaces the factory's ruleset; only streams
// created after this call see the new ruleset.
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset = r
	f.RulesetMutex.Unlock()
	return nil
}
// tcpStream is the per-connection analysis state fed by the reassembler.
type tcpStream struct {
	info          ruleset.StreamInfo
	virgin        bool // true if no packets have been processed
	logger        Logger
	ruleset       ruleset.Ruleset   // snapshot taken at stream creation
	activeEntries []*tcpStreamEntry // analyzers still consuming data
	doneEntries   []*tcpStreamEntry // analyzers finished or closed
	lastVerdict   tcpVerdict        // replayed once analysis is settled
}

// tcpStreamEntry pairs one analyzer's stream with its byte-quota state.
type tcpStreamEntry struct {
	Name     string
	Stream   analyzer.TCPStream
	HasLimit bool // true when the analyzer declared a positive Limit()
	Quota    int  // remaining bytes this analyzer may consume
}
// Accept implements reassembly.Stream. Segments keep flowing to
// ReassembledSG while any analyzer entry is active or while the stream has
// not yet been matched against the ruleset ("virgin"); afterwards the
// cached verdict is replayed and reassembly is skipped.
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
	if len(s.activeEntries) == 0 && !s.virgin {
		// Analysis is settled; replay the last verdict.
		ac.(*tcpContext).Verdict = s.lastVerdict
		return false
	}
	// Make sure every stream matches against the ruleset at least once,
	// even if there are no activeEntries, as the ruleset may have built-in
	// properties that need to be matched.
	return true
}
// ReassembledSG implements reassembly.Stream. It feeds newly reassembled
// bytes to every active analyzer entry, merges their property updates into
// the stream info, and re-matches the ruleset whenever a property changed
// (or on the very first call). A definitive match result becomes the
// stream verdict and shuts down all remaining entries.
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
	dir, start, end, skip := sg.Info()
	rev := dir == reassembly.TCPDirServerToClient
	avail, _ := sg.Lengths()
	data := sg.Fetch(avail)
	updated := false
	for i := len(s.activeEntries) - 1; i >= 0; i-- {
		// Important: reverse order so we can remove entries
		entry := s.activeEntries[i]
		update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
		up1 := processPropUpdate(s.info.Props, entry.Name, update)
		up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
		updated = updated || up1 || up2
		if done {
			s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
			s.doneEntries = append(s.doneEntries, entry)
		}
	}
	ctx := ac.(*tcpContext)
	if updated || s.virgin {
		s.virgin = false
		s.logger.TCPStreamPropUpdate(s.info, false)
		// Match properties against ruleset
		result := s.ruleset.Match(s.info)
		action := result.Action
		// Maybe keeps analyzing; Modify is not supported for TCP so it is
		// treated the same. Any other action is final for the stream.
		if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
			verdict := actionToTCPVerdict(action)
			s.lastVerdict = verdict
			ctx.Verdict = verdict
			s.logger.TCPStreamAction(s.info, action, false)
			// Verdict issued, no need to process any more packets
			s.closeActiveEntries()
		}
	}
	if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
		// All entries are done but no verdict issued, accept stream
		s.lastVerdict = tcpVerdictAcceptStream
		ctx.Verdict = tcpVerdictAcceptStream
		s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// ReassemblyComplete implements reassembly.Stream. Called when the
// connection ends; closes any remaining analyzer entries. Returning true
// lets the assembler remove the connection immediately.
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
	s.closeActiveEntries()
	return true
}
// closeActiveEntries signals a non-quota close to every remaining active
// entry, merges any final property updates (logging if something changed),
// and retires all entries to doneEntries.
func (s *tcpStream) closeActiveEntries() {
	anyUpdate := false
	for _, e := range s.activeEntries {
		finalUpdate := e.Stream.Close(false)
		if processPropUpdate(s.info.Props, e.Name, finalUpdate) {
			anyUpdate = true
		}
	}
	if anyUpdate {
		s.logger.TCPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry pushes reassembled data into a single analyzer entry,
// truncating to the entry's remaining byte quota when one is set. It
// returns the entry's property updates and whether the entry is finished.
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	if !entry.HasLimit {
		update, done = entry.Stream.Feed(rev, start, end, skip, data)
		return
	}
	// Feed at most the remaining quota.
	chunk := data
	if len(chunk) > entry.Quota {
		chunk = chunk[:entry.Quota]
	}
	update, done = entry.Stream.Feed(rev, start, end, skip, chunk)
	entry.Quota -= len(chunk)
	if entry.Quota <= 0 {
		// Quota exhausted, signal close & move to doneEntries
		closeUpdate = entry.Stream.Close(true)
		done = true
	}
	return
}
// analyzersToTCPAnalyzers filters the list down to analyzers that
// implement the TCP analyzer interface, preserving order.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
	out := make([]analyzer.TCPAnalyzer, 0, len(ans))
	for _, candidate := range ans {
		tcpA, ok := candidate.(analyzer.TCPAnalyzer)
		if !ok {
			continue
		}
		out = append(out, tcpA)
	}
	return out
}
// actionToTCPVerdict maps a ruleset action onto a whole-stream TCP
// verdict. Block and Drop are equivalent for TCP (single-packet drops are
// not supported); everything else — including any unexpected action —
// accepts the stream.
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
	if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
		return tcpVerdictDropStream
	}
	// Maybe / Allow / Modify, plus the should-never-happen fallback.
	return tcpVerdictAcceptStream
}

299
engine/udp.go Normal file
View File

@@ -0,0 +1,299 @@
package engine
import (
"errors"
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2"
)
// udpVerdict is a subset of io.Verdict for UDP streams.
// For UDP, we support all verdicts.
type udpVerdict io.Verdict

const (
	udpVerdictAccept       = udpVerdict(io.VerdictAccept)
	udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify)
	udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream)
	udpVerdictDrop         = udpVerdict(io.VerdictDrop)
	udpVerdictDropStream   = udpVerdict(io.VerdictDropStream)
)

// errInvalidModifier is reported when a rule's modifier instance does not
// implement modifier.UDPModifierInstance.
var errInvalidModifier = errors.New("invalid modifier")

// udpContext carries the per-packet verdict and, when the verdict is
// udpVerdictAcceptModify, the modified payload produced by the modifier.
type udpContext struct {
	Verdict udpVerdict
	Packet  []byte // modified payload; nil unless a modifier ran successfully
}
// udpStreamFactory creates udpStreams. It holds the worker's current
// ruleset, guarded by RulesetMutex so UpdateRuleset can swap it while
// streams are being created.
type udpStreamFactory struct {
	WorkerID     int
	Logger       Logger
	Node         *snowflake.Node // generator for unique stream IDs
	RulesetMutex sync.RWMutex
	Ruleset      ruleset.Ruleset
}
// New creates a udpStream for the given flows. It assigns the stream a
// snowflake ID, logs its creation, snapshots the current ruleset, and
// instantiates one entry per UDP-capable analyzer the ruleset selects for
// this stream.
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
	sid := f.Node.Generate()
	srcIP := net.IP(ipFlow.Src().Raw())
	dstIP := net.IP(ipFlow.Dst().Raw())
	streamInfo := ruleset.StreamInfo{
		ID:       sid.Int64(),
		Protocol: ruleset.ProtocolUDP,
		SrcIP:    srcIP,
		DstIP:    dstIP,
		SrcPort:  uint16(udp.SrcPort),
		DstPort:  uint16(udp.DstPort),
		Props:    make(analyzer.CombinedPropMap),
	}
	f.Logger.UDPStreamNew(f.WorkerID, streamInfo)
	// Snapshot the ruleset under the read lock.
	f.RulesetMutex.RLock()
	rs := f.Ruleset
	f.RulesetMutex.RUnlock()
	udpAns := analyzersToUDPAnalyzers(rs.Analyzers(streamInfo))
	// Create entries for each analyzer
	entryList := make([]*udpStreamEntry, 0, len(udpAns))
	for _, an := range udpAns {
		udpInfo := analyzer.UDPInfo{
			SrcIP:   srcIP,
			DstIP:   dstIP,
			SrcPort: uint16(udp.SrcPort),
			DstPort: uint16(udp.DstPort),
		}
		alog := &analyzerLogger{
			StreamID: sid.Int64(),
			Name:     an.Name(),
			Logger:   f.Logger,
		}
		entryList = append(entryList, &udpStreamEntry{
			Name:     an.Name(),
			Stream:   an.NewUDP(udpInfo, alog),
			HasLimit: an.Limit() > 0,
			Quota:    an.Limit(),
		})
	}
	return &udpStream{
		info:          streamInfo,
		virgin:        true,
		logger:        f.Logger,
		ruleset:       rs,
		activeEntries: entryList,
	}
}
// UpdateRuleset atomically replaces the factory's ruleset; only streams
// created after this call see the new ruleset.
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset = r
	f.RulesetMutex.Unlock()
	return nil
}
// udpStreamManager maps IO-layer stream IDs to udpStreams via a
// size-bounded LRU cache.
type udpStreamManager struct {
	factory *udpStreamFactory
	streams *lru.Cache[uint32, *udpStreamValue]
}

// udpStreamValue pairs a stream with the flows it was created for, so a
// stream ID collision can be detected and the direction recovered.
type udpStreamValue struct {
	Stream  *udpStream
	IPFlow  gopacket.Flow
	UDPFlow gopacket.Flow
}
// Match reports whether the given flows belong to this stream value, and
// if so, whether they run in the reverse (server-to-client) direction.
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
	forward := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
	reverse := v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
	return forward || reverse, reverse
}
// newUDPStreamManager builds a manager whose stream table is an LRU cache
// capped at maxStreams entries.
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
	cache, err := lru.New[uint32, *udpStreamValue](maxStreams)
	if err != nil {
		return nil, err
	}
	m := &udpStreamManager{
		factory: factory,
		streams: cache,
	}
	return m, nil
}
// MatchWithContext finds or creates the stream for the given stream ID and
// flows, then feeds the UDP payload to it. If the cached entry's flows do
// not match (stream ID reuse/collision), the old stream is closed and
// replaced with a fresh one.
// NOTE(review): entries evicted by the LRU cache are dropped without
// Close() being called on their stream — confirm analyzers tolerate losing
// the close signal.
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
	rev := false
	value, ok := m.streams.Get(streamID)
	if !ok {
		// New stream
		value = &udpStreamValue{
			Stream:  m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
			IPFlow:  ipFlow,
			UDPFlow: udp.TransportFlow(),
		}
		m.streams.Add(streamID, value)
	} else {
		// Stream ID exists, but is it really the same stream?
		ok, rev = value.Match(ipFlow, udp.TransportFlow())
		if !ok {
			// It's not - close the old stream & replace it with a new one
			value.Stream.Close()
			value = &udpStreamValue{
				Stream:  m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
				IPFlow:  ipFlow,
				UDPFlow: udp.TransportFlow(),
			}
			m.streams.Add(streamID, value)
		}
	}
	if value.Stream.Accept(udp, rev, uc) {
		value.Stream.Feed(udp, rev, uc)
	}
}
// udpStream is the per-stream analysis state for a UDP 5-tuple.
type udpStream struct {
	info          ruleset.StreamInfo
	virgin        bool // true if no packets have been processed
	logger        Logger
	ruleset       ruleset.Ruleset   // snapshot taken at stream creation
	activeEntries []*udpStreamEntry // analyzers still consuming data
	doneEntries   []*udpStreamEntry // analyzers finished or closed
	lastVerdict   udpVerdict        // replayed once analysis is settled
}

// udpStreamEntry pairs one analyzer's stream with its byte-quota state.
type udpStreamEntry struct {
	Name     string
	Stream   analyzer.UDPStream
	HasLimit bool // true when the analyzer declared a positive Limit()
	Quota    int  // remaining bytes this analyzer may consume
}
// Accept reports whether this datagram should be fed to the analyzers.
// Datagrams keep flowing while any entry is active or while the stream has
// not yet been matched against the ruleset ("virgin"); afterwards the
// cached verdict is replayed.
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
	if len(s.activeEntries) == 0 && !s.virgin {
		// Analysis is settled; replay the last verdict.
		uc.Verdict = s.lastVerdict
		return false
	}
	// Make sure every stream matches against the ruleset at least once,
	// even if there are no activeEntries, as the ruleset may have built-in
	// properties that need to be matched.
	return true
}
// Feed passes one datagram's payload to every active analyzer entry,
// merges their property updates, and re-matches the ruleset whenever a
// property changed (or on the very first datagram). A Modify result runs
// the rule's modifier on the payload; other definitive results become the
// stream verdict.
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
	updated := false
	for i := len(s.activeEntries) - 1; i >= 0; i-- {
		// Important: reverse order so we can remove entries
		entry := s.activeEntries[i]
		update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
		up1 := processPropUpdate(s.info.Props, entry.Name, update)
		up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
		updated = updated || up1 || up2
		if done {
			s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
			s.doneEntries = append(s.doneEntries, entry)
		}
	}
	if updated || s.virgin {
		s.virgin = false
		s.logger.UDPStreamPropUpdate(s.info, false)
		// Match properties against ruleset
		result := s.ruleset.Match(s.info)
		action := result.Action
		if action == ruleset.ActionModify {
			// Call the modifier instance
			udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance)
			if !ok {
				// Not for UDP, fallback to maybe
				s.logger.ModifyError(s.info, errInvalidModifier)
				action = ruleset.ActionMaybe
			} else {
				var err error
				uc.Packet, err = udpMI.Process(udp.Payload)
				if err != nil {
					// Modifier error, fallback to maybe
					s.logger.ModifyError(s.info, err)
					action = ruleset.ActionMaybe
				}
			}
		}
		if action != ruleset.ActionMaybe {
			verdict, final := actionToUDPVerdict(action)
			s.lastVerdict = verdict
			uc.Verdict = verdict
			s.logger.UDPStreamAction(s.info, action, false)
			// Final verdicts (Allow/Block) end analysis; Drop and Modify
			// remain per-packet and keep entries alive.
			if final {
				s.closeActiveEntries()
			}
		}
	}
	if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept {
		// All entries are done but no verdict issued, accept stream
		s.lastVerdict = udpVerdictAcceptStream
		uc.Verdict = udpVerdictAcceptStream
		s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// Close shuts down the stream's remaining analyzer entries. Called by the
// stream manager when this stream is replaced after a stream ID collision.
func (s *udpStream) Close() {
	s.closeActiveEntries()
}
// closeActiveEntries signals a non-quota close to every remaining active
// entry, merges any final property updates (logging if something changed),
// and retires all entries to doneEntries.
func (s *udpStream) closeActiveEntries() {
	anyUpdate := false
	for _, e := range s.activeEntries {
		finalUpdate := e.Stream.Close(false)
		if processPropUpdate(s.info.Props, e.Name, finalUpdate) {
			anyUpdate = true
		}
	}
	if anyUpdate {
		s.logger.UDPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry pushes one datagram payload into a single analyzer entry and
// applies the entry's byte quota if it has one.
// NOTE(review): unlike the TCP path, the full payload is fed before the
// quota check, so the last datagram may exceed the remaining quota and
// Quota can go negative — confirm this asymmetry is intentional (e.g. to
// avoid truncating a datagram mid-message).
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	update, done = entry.Stream.Feed(rev, data)
	if entry.HasLimit {
		entry.Quota -= len(data)
		if entry.Quota <= 0 {
			// Quota exhausted, signal close & move to doneEntries
			closeUpdate = entry.Stream.Close(true)
			done = true
		}
	}
	return
}
// analyzersToUDPAnalyzers filters the list down to analyzers that
// implement the UDP analyzer interface, preserving order.
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
	out := make([]analyzer.UDPAnalyzer, 0, len(ans))
	for _, candidate := range ans {
		udpA, ok := candidate.(analyzer.UDPAnalyzer)
		if !ok {
			continue
		}
		out = append(out, udpA)
	}
	return out
}
// actionToUDPVerdict maps a ruleset action onto a UDP verdict, also
// reporting whether the verdict is final (no further packets of the
// stream need analysis). Drop and Modify stay per-packet, so they are not
// final.
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
	switch a {
	case ruleset.ActionAllow:
		return udpVerdictAcceptStream, true
	case ruleset.ActionBlock:
		return udpVerdictDropStream, true
	case ruleset.ActionDrop:
		return udpVerdictDrop, false
	case ruleset.ActionModify:
		return udpVerdictAcceptModify, false
	case ruleset.ActionMaybe:
		return udpVerdictAccept, false
	default:
		// Should never happen
		return udpVerdictAccept, false
	}
}

50
engine/utils.go Normal file
View File

@@ -0,0 +1,50 @@
package engine
import "git.difuse.io/Difuse/Mellaris/analyzer"
// Compile-time check that *analyzerLogger implements analyzer.Logger.
var _ analyzer.Logger = (*analyzerLogger)(nil)

// analyzerLogger adapts the engine Logger into the analyzer.Logger
// interface, tagging every message with a stream ID and analyzer name.
type analyzerLogger struct {
	StreamID int64
	Name     string // the analyzer's name
	Logger   Logger
}
// Debugf forwards a debug message to the engine logger with this
// analyzer's stream ID and name.
func (l *analyzerLogger) Debugf(format string, args ...interface{}) {
	l.Logger.AnalyzerDebugf(l.StreamID, l.Name, format, args...)
}

// Infof forwards an info message to the engine logger with this
// analyzer's stream ID and name.
func (l *analyzerLogger) Infof(format string, args ...interface{}) {
	l.Logger.AnalyzerInfof(l.StreamID, l.Name, format, args...)
}

// Errorf forwards an error message to the engine logger with this
// analyzer's stream ID and name.
func (l *analyzerLogger) Errorf(format string, args ...interface{}) {
	l.Logger.AnalyzerErrorf(l.StreamID, l.Name, format, args...)
}
// processPropUpdate applies one analyzer's property update to the combined
// property map under that analyzer's name, reporting whether the map was
// modified.
func processPropUpdate(cpm analyzer.CombinedPropMap, name string, update *analyzer.PropUpdate) (updated bool) {
	if update == nil {
		return false
	}
	switch update.Type {
	case analyzer.PropUpdateMerge:
		// Merge into the analyzer's existing map, creating it on demand.
		dst := cpm[name]
		if dst == nil {
			dst = make(analyzer.PropMap, len(update.M))
			cpm[name] = dst
		}
		for key, value := range update.M {
			dst[key] = value
		}
		return true
	case analyzer.PropUpdateReplace:
		cpm[name] = update.M
		return true
	case analyzer.PropUpdateDelete:
		delete(cpm, name)
		return true
	default:
		// PropUpdateNone and any unknown types are no-ops.
		return false
	}
}

185
engine/worker.go Normal file
View File

@@ -0,0 +1,185 @@
package engine
import (
"context"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// Defaults for workerConfig tunables, applied by fillDefaults when the
// configured value is zero or negative.
const (
	defaultChanSize                         = 64
	defaultTCPMaxBufferedPagesTotal         = 4096
	defaultTCPMaxBufferedPagesPerConnection = 64
	defaultUDPMaxStreams                    = 4096
)

// workerPacket is the unit of work queued to a worker: a decoded packet
// plus a callback for delivering the verdict back to the IO layer.
type workerPacket struct {
	StreamID   uint32
	Packet     gopacket.Packet
	SetVerdict func(io.Verdict, []byte) error
}
// worker processes packets from its queue on a single goroutine, running
// TCP reassembly and UDP stream tracking for its shard of streams.
type worker struct {
	id                 int
	packetChan         chan *workerPacket
	logger             Logger
	tcpStreamFactory   *tcpStreamFactory
	tcpStreamPool      *reassembly.StreamPool
	tcpAssembler       *reassembly.Assembler
	udpStreamFactory   *udpStreamFactory
	udpStreamManager   *udpStreamManager
	modSerializeBuffer gopacket.SerializeBuffer // scratch buffer for re-serializing modified UDP packets
}

// workerConfig configures a single worker; zero/negative tunables fall
// back to the defaults above via fillDefaults.
type workerConfig struct {
	ID                         int
	ChanSize                   int
	Logger                     Logger
	Ruleset                    ruleset.Ruleset
	TCPMaxBufferedPagesTotal   int
	TCPMaxBufferedPagesPerConn int
	UDPMaxStreams              int
}
// fillDefaults replaces any non-positive tunables with their default
// values.
func (c *workerConfig) fillDefaults() {
	defaultTo := func(field *int, def int) {
		if *field <= 0 {
			*field = def
		}
	}
	defaultTo(&c.ChanSize, defaultChanSize)
	defaultTo(&c.TCPMaxBufferedPagesTotal, defaultTCPMaxBufferedPagesTotal)
	defaultTo(&c.TCPMaxBufferedPagesPerConn, defaultTCPMaxBufferedPagesPerConnection)
	defaultTo(&c.UDPMaxStreams, defaultUDPMaxStreams)
}
// newWorker constructs a worker with its TCP reassembler and UDP stream
// manager, applying defaults for any unset tunables.
func newWorker(config workerConfig) (*worker, error) {
	config.fillDefaults()
	// Per-worker snowflake node (keyed by worker ID) keeps stream IDs
	// unique across workers.
	node, err := snowflake.NewNode(int64(config.ID))
	if err != nil {
		return nil, err
	}
	tcpFactory := &tcpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     node,
		Ruleset:  config.Ruleset,
	}
	pool := reassembly.NewStreamPool(tcpFactory)
	assembler := reassembly.NewAssembler(pool)
	assembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
	assembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
	udpFactory := &udpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     node,
		Ruleset:  config.Ruleset,
	}
	udpManager, err := newUDPStreamManager(udpFactory, config.UDPMaxStreams)
	if err != nil {
		return nil, err
	}
	return &worker{
		id:                 config.ID,
		packetChan:         make(chan *workerPacket, config.ChanSize),
		logger:             config.Logger,
		tcpStreamFactory:   tcpFactory,
		tcpStreamPool:      pool,
		tcpAssembler:       assembler,
		udpStreamFactory:   udpFactory,
		udpStreamManager:   udpManager,
		modSerializeBuffer: gopacket.NewSerializeBuffer(),
	}, nil
}
// Feed queues a packet for this worker. Blocks when the worker's channel
// is full, applying backpressure to the dispatcher.
func (w *worker) Feed(p *workerPacket) {
	w.packetChan <- p
}
// Run processes packets from the worker's queue until the context is
// cancelled or a nil packet (close signal) is received, logging worker
// start/stop events.
func (w *worker) Run(ctx context.Context) {
	w.logger.WorkerStart(w.id)
	defer w.logger.WorkerStop(w.id)
	for {
		select {
		case <-ctx.Done():
			return
		case pkt := <-w.packetChan:
			if pkt == nil {
				// Closed
				return
			}
			verdict, modified := w.handle(pkt.StreamID, pkt.Packet)
			_ = pkt.SetVerdict(verdict, modified)
		}
	}
}
// UpdateRuleset propagates the new ruleset to both the TCP and UDP stream
// factories.
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
	if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil {
		return err
	}
	if err := w.udpStreamFactory.UpdateRuleset(r); err != nil {
		return err
	}
	return nil
}
// handle decodes a packet's network/transport layers and routes it to the
// TCP or UDP path. For UDP packets modified by a rule it re-serializes the
// packet with fixed lengths & checksums and returns the new bytes.
func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) {
	netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
	if netLayer == nil || trLayer == nil {
		// Invalid packet
		return io.VerdictAccept, nil
	}
	ipFlow := netLayer.NetworkFlow()
	switch tr := trLayer.(type) {
	case *layers.TCP:
		return w.handleTCP(ipFlow, p.Metadata(), tr), nil
	case *layers.UDP:
		v, modPayload := w.handleUDP(streamID, ipFlow, tr)
		if v == io.VerdictAcceptModify && modPayload != nil {
			// Splice the modified payload into the packet and recompute
			// lengths/checksums before serializing.
			tr.Payload = modPayload
			_ = tr.SetNetworkLayerForChecksum(netLayer)
			_ = w.modSerializeBuffer.Clear()
			err := gopacket.SerializePacket(w.modSerializeBuffer,
				gopacket.SerializeOptions{
					FixLengths:       true,
					ComputeChecksums: true,
				}, p)
			if err != nil {
				// Just accept without modification for now
				return io.VerdictAccept, nil
			}
			// NOTE(review): the returned slice aliases modSerializeBuffer,
			// which is cleared on the next modified packet — confirm the
			// IO layer copies/consumes it before the worker proceeds.
			return v, w.modSerializeBuffer.Bytes()
		}
		return v, nil
	default:
		// Unsupported protocol
		return io.VerdictAccept, nil
	}
}
// handleTCP runs a TCP segment through the reassembler and returns the
// verdict the stream recorded in the per-packet context.
func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
	ctx := tcpContext{
		PacketMetadata: pMeta,
		Verdict:        tcpVerdictAccept,
	}
	w.tcpAssembler.AssembleWithContext(ipFlow, tcp, &ctx)
	return io.Verdict(ctx.Verdict)
}
// handleUDP runs a UDP datagram through the stream manager and returns the
// recorded verdict plus any modified payload.
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
	ctx := udpContext{
		Verdict: udpVerdictAccept,
	}
	w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, &ctx)
	return io.Verdict(ctx.Verdict), ctx.Packet
}