241 lines
5.5 KiB
Go
241 lines
5.5 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/io"
|
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
|
)
|
|
|
|
// Compile-time assertion that *engine satisfies the Engine interface; the
// build fails here if a method is missing or has the wrong signature.
var _ Engine = (*engine)(nil)
|
|
|
|
// verdictEntry is a cached stream-level verdict stored in engine.verdicts,
// keyed by stream ID.
type verdictEntry struct {
	// Verdict is re-applied by dispatch to later packets of the same stream.
	Verdict io.Verdict
	// Gen is the verdict generation copied from the worker result
	// (presumably the generation stamped on the packet at dispatch — confirm
	// in worker). dispatch only serves the entry while Gen equals the
	// engine's current verdictsGen.
	Gen int64
	// CreatedAt is when the entry was stored; sweepVerdicts evicts entries
	// older than verdictTTL.
	CreatedAt time.Time
}
|
|
|
|
const (
	// verdictTTL is the age after which sweepVerdicts evicts a cached stream
	// verdict. dispatch does not check the age itself, and expiry is only
	// evaluated on each sweep tick, so an entry can keep being served for up
	// to verdictTTL+verdictSweepInterval after it was cached.
	verdictTTL = 15 * time.Second

	// verdictSweepInterval is the period of the ticker that drives
	// sweepVerdicts.
	verdictSweepInterval = 15 * time.Second
)
|
|
|
|
// engine is the Engine implementation returned by NewEngine. dispatch shards
// packets across a fixed set of workers by stream ID, and stream-level
// verdicts are cached so later packets of a decided stream skip analysis.
type engine struct {
	logger Logger

	// io is the packet source/sink: Run registers the packet callback on it
	// and every verdict is applied through it.
	io io.PacketIO

	// workers is fixed at construction; dispatch indexes it by
	// streamID % len(workers).
	workers []*worker

	// stats is the counter set shared with every worker (see NewEngine).
	stats *statsCounters

	verdicts sync.Map // streamID(uint32) -> verdictEntry

	verdictsGen atomic.Int64 // incremented on ruleset update

	// overflowPolicy selects what dispatch does when a worker's queue is full.
	overflowPolicy OverflowPolicy

	// resultCh carries per-packet results from all workers to drainResults.
	resultCh chan workerResult
}
|
|
|
|
func NewEngine(config Config) (Engine, error) {
|
|
workerCount := config.Workers
|
|
if workerCount <= 0 {
|
|
workerCount = runtime.GOMAXPROCS(0)
|
|
if workerCount <= 0 {
|
|
workerCount = 1
|
|
}
|
|
}
|
|
overflowPolicy := config.OverflowPolicy
|
|
if overflowPolicy == "" {
|
|
overflowPolicy = OverflowPolicyDrop
|
|
}
|
|
selectionMode := config.AnalyzerSelectionMode
|
|
if selectionMode == "" {
|
|
selectionMode = AnalyzerSelectionModeSignature
|
|
}
|
|
|
|
stats := &statsCounters{}
|
|
resultCh := make(chan workerResult, workerCount*256)
|
|
|
|
macResolver := newSourceMACResolver()
|
|
var err error
|
|
workers := make([]*worker, workerCount)
|
|
for i := range workers {
|
|
workers[i], err = newWorker(workerConfig{
|
|
ID: i,
|
|
ChanSize: config.WorkerQueueSize,
|
|
Logger: config.Logger,
|
|
Ruleset: config.Ruleset,
|
|
MACResolver: macResolver,
|
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
|
AnalyzerSelectionMode: selectionMode,
|
|
ResultChan: resultCh,
|
|
Stats: stats,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
e := &engine{
|
|
logger: config.Logger,
|
|
io: config.IO,
|
|
workers: workers,
|
|
stats: stats,
|
|
overflowPolicy: overflowPolicy,
|
|
resultCh: resultCh,
|
|
}
|
|
return e, nil
|
|
}
|
|
|
|
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
|
e.verdictsGen.Add(1)
|
|
for _, w := range e.workers {
|
|
if err := w.UpdateRuleset(r); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *engine) Run(ctx context.Context) error {
|
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
|
defer ioCancel()
|
|
|
|
for _, w := range e.workers {
|
|
go w.Run(ioCtx)
|
|
}
|
|
go e.drainResults(ioCtx)
|
|
go e.sweepVerdicts(ioCtx)
|
|
|
|
errChan := make(chan error, 1)
|
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
|
if err != nil {
|
|
errChan <- err
|
|
return false
|
|
}
|
|
return e.dispatch(p)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
return err
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (e *engine) dispatch(p io.Packet) bool {
|
|
streamID := p.StreamID()
|
|
|
|
if streamID != 0 {
|
|
if v, ok := e.verdicts.Load(streamID); ok {
|
|
entry := v.(verdictEntry)
|
|
if entry.Gen == e.verdictsGen.Load() {
|
|
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
data := p.Data()
|
|
if !validPacket(data) {
|
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
|
return true
|
|
}
|
|
gen := e.verdictsGen.Load()
|
|
index := streamID % uint32(len(e.workers))
|
|
wp := &workerPacket{
|
|
Packet: p,
|
|
StreamID: streamID,
|
|
Data: data,
|
|
Gen: gen,
|
|
}
|
|
if !e.workers[index].Feed(wp) {
|
|
e.stats.OverflowEvents.Add(1)
|
|
switch e.overflowPolicy {
|
|
case OverflowPolicyDrop:
|
|
e.stats.OverflowDrops.Add(1)
|
|
_ = e.io.SetVerdict(p, io.VerdictDrop, nil)
|
|
case OverflowPolicyBackpressure:
|
|
e.stats.OverflowBackpressureEvents.Add(1)
|
|
e.workers[index].FeedBlocking(wp)
|
|
default:
|
|
e.stats.OverflowAccepts.Add(1)
|
|
_ = e.io.SetVerdict(p, io.VerdictAccept, nil)
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (e *engine) applyWorkerResult(r workerResult) {
|
|
if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
|
|
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
|
|
}
|
|
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
|
}
|
|
|
|
func (e *engine) sweepVerdicts(ctx context.Context) {
|
|
ticker := time.NewTicker(verdictSweepInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
now := time.Now()
|
|
e.verdicts.Range(func(key, value interface{}) bool {
|
|
entry := value.(verdictEntry)
|
|
if now.Sub(entry.CreatedAt) > verdictTTL {
|
|
e.verdicts.Delete(key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// validPacket reports whether dispatch should hand data to a worker. It
// accepts either:
//   - a raw IP packet whose version nibble (high 4 bits of byte 0) is 4 or 6;
//   - an Ethernet II frame of at least 14 bytes whose EtherType (bytes 12-13,
//     big-endian) is IPv4 (0x0800) or IPv6 (0x86DD).
//
// Only those header bytes are inspected; the rest of the packet is not
// validated. Empty input and other EtherTypes (including VLAN-tagged frames,
// 0x8100) are rejected.
func validPacket(data []byte) bool {
	if len(data) == 0 {
		return false
	}

	switch data[0] >> 4 {
	case 4, 6:
		return true
	}

	if len(data) < 14 {
		return false
	}

	switch uint16(data[12])<<8 | uint16(data[13]) {
	case 0x0800, 0x86DD:
		return true
	default:
		return false
	}
}
|
|
|
|
func (e *engine) drainResults(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case r := <-e.resultCh:
|
|
e.applyWorkerResult(r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *engine) Stats() Stats {
|
|
return Stats{
|
|
OverflowEvents: e.stats.OverflowEvents.Load(),
|
|
OverflowAccepts: e.stats.OverflowAccepts.Load(),
|
|
OverflowDrops: e.stats.OverflowDrops.Load(),
|
|
OverflowBackpressureEvents: e.stats.OverflowBackpressureEvents.Load(),
|
|
AnalyzerSelectionsTotal: e.stats.AnalyzerSelectionsTotal.Load(),
|
|
AnalyzerSelectionsPruned: e.stats.AnalyzerSelectionsPruned.Load(),
|
|
UDPTupleLookups: e.stats.UDPTupleLookups.Load(),
|
|
UDPTupleHits: e.stats.UDPTupleHits.Load(),
|
|
}
|
|
}
|