// Mellaris/engine/engine.go
//
// Packet-analysis engine: fans incoming packets out to a fixed pool of
// workers and caches per-stream verdicts with a generation counter and TTL.
package engine
import (
"context"
"runtime"
"sync"
"sync/atomic"
"time"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Compile-time assertion that *engine implements the Engine interface.
var _ Engine = (*engine)(nil)

// verdictEntry is one cached per-stream verdict, replayed by dispatch for
// later packets of the same stream.
type verdictEntry struct {
Verdict io.Verdict // verdict to apply to subsequent packets of the stream
Gen int64 // ruleset generation the verdict was computed under; stale gens are ignored
CreatedAt time.Time // insertion time, compared against verdictTTL by the sweeper
}

const (
// verdictTTL is how long a cached stream verdict remains eligible for reuse.
verdictTTL = 15 * time.Second
// verdictSweepInterval is how often sweepVerdicts scans for expired entries.
verdictSweepInterval = 15 * time.Second
)
// engine is the concrete Engine implementation. It owns the worker pool,
// the shared stats counters, and the per-stream verdict cache.
type engine struct {
logger Logger
io io.PacketIO // packet source/sink; verdicts are delivered via io.SetVerdict
workers []*worker // fixed pool; a stream always maps to the same worker (streamID % len)
stats *statsCounters // shared atomic counters, snapshotted by Stats()
verdicts sync.Map // streamID(uint32) -> verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update; invalidates cached verdicts
overflowPolicy OverflowPolicy // applied by dispatch when a worker queue is full
resultCh chan workerResult // worker verdicts, consumed by drainResults
}
// NewEngine constructs an Engine from config. It creates one analysis worker
// per configured slot — defaulting to GOMAXPROCS, with a floor of one — plus
// the shared result channel and stats counters. The first error returned by
// a worker constructor aborts construction.
func NewEngine(config Config) (Engine, error) {
	numWorkers := config.Workers
	if numWorkers <= 0 {
		if numWorkers = runtime.GOMAXPROCS(0); numWorkers <= 0 {
			numWorkers = 1
		}
	}

	policy := config.OverflowPolicy
	if policy == "" {
		policy = OverflowPolicyDrop
	}
	mode := config.AnalyzerSelectionMode
	if mode == "" {
		mode = AnalyzerSelectionModeSignature
	}

	counters := &statsCounters{}
	// Generously buffered so workers rarely block when handing back results.
	results := make(chan workerResult, numWorkers*256)
	resolver := newSourceMACResolver()

	pool := make([]*worker, numWorkers)
	for id := range pool {
		w, err := newWorker(workerConfig{
			ID:                         id,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			MACResolver:                resolver,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
			AnalyzerSelectionMode:      mode,
			ResultChan:                 results,
			Stats:                      counters,
		})
		if err != nil {
			return nil, err
		}
		pool[id] = w
	}

	return &engine{
		logger:         config.Logger,
		io:             config.IO,
		workers:        pool,
		stats:          counters,
		overflowPolicy: policy,
		resultCh:       results,
	}, nil
}
// UpdateRuleset installs r on every worker and bumps the verdict-cache
// generation so previously cached stream verdicts stop being considered
// current by dispatch.
//
// The generation is advanced before any worker is updated, so an old verdict
// can never be replayed once this call has started. On a worker error the
// update stops and the error is returned; earlier workers keep the new
// ruleset while later ones keep the old. NOTE(review): confirm callers treat
// a returned error as fatal, since the pool is left in a mixed state.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	e.verdictsGen.Add(1)
	for i := range e.workers {
		if err := e.workers[i].UpdateRuleset(r); err != nil {
			return err
		}
	}
	return nil
}
// Run starts the workers, the result drainer, and the verdict sweeper, then
// registers the packet callback with the IO layer and blocks until either the
// callback reports an error or ctx is canceled. Cancellation returns nil;
// all goroutines are stopped via the derived ioCtx on return.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel()

	for _, w := range e.workers {
		go w.Run(ioCtx)
	}
	go e.drainResults(ioCtx)
	go e.sweepVerdicts(ioCtx)

	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			// Non-blocking send: errChan holds one error and Run may have
			// already returned. A blocking send here would wedge the IO
			// callback goroutine forever if a second error were delivered;
			// only the first error is reported, later ones are dropped.
			select {
			case errChan <- err:
			default:
			}
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}
	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		// Graceful shutdown: context cancellation is not treated as an error.
		return nil
	}
}
// dispatch routes one packet: it either replays a cached stream verdict or
// hands the packet to a worker for analysis. It always returns true so the
// IO callback keeps delivering packets; per-packet verdicts go out through
// e.io.SetVerdict.
func (e *engine) dispatch(p io.Packet) bool {
streamID := p.StreamID()
// Fast path: reuse a cached stream verdict, but only if it was produced
// under the current ruleset generation. streamID 0 means "no stream" and
// never consults the cache.
// NOTE(review): the TTL is not checked here, only by the sweeper, so an
// entry can be replayed up to verdictSweepInterval past verdictTTL —
// confirm that staleness window is acceptable.
if streamID != 0 {
if v, ok := e.verdicts.Load(streamID); ok {
entry := v.(verdictEntry)
if entry.Gen == e.verdictsGen.Load() {
_ = e.io.SetVerdict(p, entry.Verdict, nil)
return true
}
}
}
data := p.Data()
// Packets that do not look like IP (raw or Ethernet-framed) are accepted
// for the whole stream without analysis.
if !validPacket(data) {
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
// Snapshot the generation before hand-off so a worker result computed under
// an old ruleset can be recognized as stale when it comes back.
gen := e.verdictsGen.Load()
// A stream always maps to the same worker, preserving per-stream ordering.
// Packets with streamID 0 all land on worker 0.
index := streamID % uint32(len(e.workers))
wp := &workerPacket{
Packet: p,
StreamID: streamID,
Data: data,
Gen: gen,
}
// Feed is non-blocking; a full worker queue triggers the overflow policy.
if !e.workers[index].Feed(wp) {
e.stats.OverflowEvents.Add(1)
switch e.overflowPolicy {
case OverflowPolicyDrop:
e.stats.OverflowDrops.Add(1)
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
case OverflowPolicyBackpressure:
e.stats.OverflowBackpressureEvents.Add(1)
// Block until the worker has room, pushing backpressure to the IO layer.
e.workers[index].FeedBlocking(wp)
default:
// Any other policy: accept the packet unanalyzed.
e.stats.OverflowAccepts.Add(1)
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
}
}
return true
}
// applyWorkerResult forwards a worker's verdict to the IO layer, first
// caching it when it is a stream-wide accept/drop decision for a real
// stream (streamID != 0).
func (e *engine) applyWorkerResult(r workerResult) {
	streamWide := r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream
	if streamWide && r.StreamID != 0 {
		e.verdicts.Store(r.StreamID, verdictEntry{
			Verdict:   r.Verdict,
			Gen:       r.Gen,
			CreatedAt: time.Now(),
		})
	}
	_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}
// sweepVerdicts periodically evicts cached stream verdicts older than
// verdictTTL. It runs until ctx is canceled.
//
// Expiry is enforced only here, once per verdictSweepInterval; dispatch does
// not re-check the TTL on lookup, so an entry may still serve reads slightly
// past its TTL.
func (e *engine) sweepVerdicts(ctx context.Context) {
	ticker := time.NewTicker(verdictSweepInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			// `any` is the modern spelling of interface{}; Range signature
			// is func(key, value any) bool.
			e.verdicts.Range(func(key, value any) bool {
				entry := value.(verdictEntry)
				if now.Sub(entry.CreatedAt) > verdictTTL {
					e.verdicts.Delete(key)
				}
				return true
			})
		}
	}
}
// validPacket reports whether data plausibly carries an IP packet, either as
// a raw IPv4/IPv6 packet or inside an Ethernet II frame.
//
// The check is heuristic: a first-byte version nibble of 4 or 6 is taken as
// raw IP; otherwise an EtherType of 0x0800 (IPv4) or 0x86DD (IPv6) at the
// standard Ethernet header offset (bytes 12-13) is accepted.
func validPacket(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	switch data[0] >> 4 {
	case 4, 6:
		return true
	}
	if len(data) < 14 {
		return false
	}
	switch uint16(data[12])<<8 | uint16(data[13]) {
	case 0x0800, 0x86DD:
		return true
	}
	return false
}
// drainResults consumes worker results from resultCh and applies their
// verdicts until ctx is canceled. Results still queued at cancellation are
// abandoned.
func (e *engine) drainResults(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case res := <-e.resultCh:
			e.applyWorkerResult(res)
		}
	}
}
// Stats returns a point-in-time snapshot of the engine's counters. Each
// counter is loaded atomically, but the snapshot as a whole is not taken
// under a single lock, so fields may reflect slightly different instants.
func (e *engine) Stats() Stats {
	c := e.stats
	var snap Stats
	snap.OverflowEvents = c.OverflowEvents.Load()
	snap.OverflowAccepts = c.OverflowAccepts.Load()
	snap.OverflowDrops = c.OverflowDrops.Load()
	snap.OverflowBackpressureEvents = c.OverflowBackpressureEvents.Load()
	snap.AnalyzerSelectionsTotal = c.AnalyzerSelectionsTotal.Load()
	snap.AnalyzerSelectionsPruned = c.AnalyzerSelectionsPruned.Load()
	snap.UDPTupleLookups = c.UDPTupleLookups.Load()
	snap.UDPTupleHits = c.UDPTupleHits.Load()
	return snap
}