refactor: engine/tcp/worker perf improvements

This commit is contained in:
2026-05-12 15:16:11 +00:00
parent dc16b979e7
commit ecc2cde1c2
9 changed files with 743 additions and 546 deletions
+40 -37
View File
@@ -2,16 +2,11 @@ package engine
import (
"context"
"encoding/binary"
"runtime"
"sync"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
var _ Engine = (*engine)(nil)
@@ -27,12 +22,15 @@ type engine struct {
workers []*worker
verdicts sync.Map // streamID(uint32) → verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update
overflowCh chan *workerPacket
overflowOnce sync.Once
}
func NewEngine(config Config) (Engine, error) {
workerCount := config.Workers
if workerCount <= 0 {
workerCount = runtime.NumCPU()
workerCount = 1
}
macResolver := newSourceMACResolver()
var err error
@@ -53,9 +51,10 @@ func NewEngine(config Config) (Engine, error) {
}
}
e := &engine{
logger: config.Logger,
io: config.IO,
workers: workers,
logger: config.Logger,
io: config.IO,
workers: workers,
overflowCh: make(chan *workerPacket, 1024),
}
return e, nil
}
@@ -75,6 +74,10 @@ func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel()
e.overflowOnce.Do(func() {
go e.drainOverflow(ioCtx)
})
for _, w := range e.workers {
go w.Run(ioCtx)
}
@@ -111,55 +114,55 @@ func (e *engine) dispatch(p io.Packet) bool {
}
data := p.Data()
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
if !ok {
if !validPacket(data) {
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
gen := e.verdictsGen.Load()
index := streamID % uint32(len(e.workers))
e.workers[index].Feed(&workerPacket{
wp := &workerPacket{
StreamID: streamID,
Data: data,
LayerType: layerType,
SrcMAC: srcMAC,
DstMAC: dstMAC,
SetVerdict: func(v io.Verdict, b []byte) error {
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
}
return e.io.SetVerdict(p, v, b)
},
})
}
if !e.workers[index].Feed(wp) {
select {
case e.overflowCh <- wp:
default:
}
}
return true
}
// classifyPacket detects packet framing and returns a gopacket decode layer
// plus best-effort source/destination MAC addresses when available.
func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
func validPacket(data []byte) bool {
if len(data) == 0 {
return 0, nil, nil, false
return false
}
// Fast path for IP packets (NFQUEUE payloads are typically IP-only).
ipVersion := data[0] >> 4
if ipVersion == 4 {
return layers.LayerTypeIPv4, nil, nil, true
if ipVersion == 4 || ipVersion == 6 {
return true
}
if ipVersion == 6 {
return layers.LayerTypeIPv6, nil, nil, true
}
// Ethernet frame path (for custom PacketIO implementations).
if len(data) >= 14 {
etherType := binary.BigEndian.Uint16(data[12:14])
if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) {
return layers.LayerTypeEthernet,
append([]byte(nil), data[6:12]...),
append([]byte(nil), data[:6]...),
true
etherType := uint16(data[12])<<8 | uint16(data[13])
if etherType == 0x0800 || etherType == 0x86DD {
return true
}
}
return false
}
func (e *engine) drainOverflow(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case wp := <-e.overflowCh:
_ = wp.SetVerdict(io.VerdictAccept, nil)
}
}
return 0, nil, nil, false
}
+91
View File
@@ -0,0 +1,91 @@
package engine
import "net"
// L3Info holds the fields extracted from an IPv4 header by ParseL3.
type L3Info struct {
    Version  uint8   // IP version (always 4 for values produced by ParseL3)
    Protocol uint8   // IANA protocol number (6 = TCP, 17 = UDP)
    IHL      uint8   // header length in 32-bit words (>= 5)
    SrcIP    [4]byte // source address, network byte order
    DstIP    [4]byte // destination address, network byte order
    Length   uint16  // effective total length, clamped to the captured data
}

// SrcIPAddr returns the source address as a net.IP.
func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) }

// DstIPAddr returns the destination address as a net.IP.
func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) }

// TCPInfo holds the fields extracted from a TCP header by ParseTCP.
type TCPInfo struct {
    SrcPort uint16
    DstPort uint16
    Seq     uint32
    Ack     uint32
    HdrLen  uint8 // header length in bytes (>= 20)
    SYN     bool
    FIN     bool
    RST     bool
    ACK     bool
}

// UDPInfo holds the ports extracted from a UDP header by ParseUDP.
type UDPInfo struct {
    SrcPort uint16
    DstPort uint16
}

// ParseL3 parses an IPv4 header from data and returns the header fields plus
// the transport-layer slice (header + payload). ok is false for anything that
// is not a well-formed IPv4 packet. A bogus total-length field (too small, or
// larger than the captured data — e.g. zero as left by TSO/GRO offload) is
// clamped to the captured size rather than rejected.
func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
    if len(data) < 20 {
        return
    }
    if data[0]>>4 != 4 {
        return
    }
    ihl := data[0] & 0x0F
    hdrLen := int(ihl) * 4
    if ihl < 5 || len(data) < hdrLen {
        return
    }
    totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
    if totalLen < hdrLen || totalLen > len(data) {
        totalLen = len(data)
    }
    return L3Info{
        Version:  4,
        Protocol: data[9],
        IHL:      ihl,
        Length:   uint16(totalLen),
        SrcIP:    [4]byte{data[12], data[13], data[14], data[15]},
        DstIP:    [4]byte{data[16], data[17], data[18], data[19]},
    }, data[hdrLen:totalLen], true
}

// ParseTCP parses a TCP header from transport and returns the header fields
// plus the payload slice. ok is false when the segment is shorter than the
// (claimed) header.
//
// Fix: the payload end offset was previously computed as
// transport[dataOff:dataOff+uint8(payloadLen)], i.e. in uint8 arithmetic;
// for payloads longer than ~235 bytes the conversion truncates/wraps,
// yielding a wrong sub-slice or an out-of-range slice panic. The equivalent
// and safe form is simply transport[dataOff:].
func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
    if len(transport) < 20 {
        return TCPInfo{}, nil, false
    }
    dataOff := int(transport[12]>>4) * 4
    if dataOff < 20 || len(transport) < dataOff {
        return TCPInfo{}, nil, false
    }
    flags := transport[13]
    return TCPInfo{
        SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
        DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
        Seq:     uint32(transport[4])<<24 | uint32(transport[5])<<16 | uint32(transport[6])<<8 | uint32(transport[7]),
        Ack:     uint32(transport[8])<<24 | uint32(transport[9])<<16 | uint32(transport[10])<<8 | uint32(transport[11]),
        HdrLen:  uint8(dataOff),
        SYN:     flags&0x02 != 0,
        FIN:     flags&0x01 != 0,
        RST:     flags&0x04 != 0,
        ACK:     flags&0x10 != 0,
    }, transport[dataOff:], true
}

// ParseUDP parses a UDP header from transport and returns the ports plus the
// payload (everything after the fixed 8-byte header). The UDP length field is
// not validated against the captured data.
func ParseUDP(transport []byte) (UDPInfo, []byte, bool) {
    if len(transport) < 8 {
        return UDPInfo{}, nil, false
    }
    return UDPInfo{
        SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
        DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
    }, transport[8:], true
}
-262
View File
@@ -1,262 +0,0 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// This is the gopacket/reassembly-based TCP path: verdicts are carried through
// the assembler via tcpContext, and analyzers are fed reassembled data in
// ReassembledSG.

// tcpVerdict is a subset of io.Verdict for TCP streams.
// We don't allow modifying or dropping a single packet
// for TCP streams for now, as it doesn't make much sense.
type tcpVerdict io.Verdict

const (
    tcpVerdictAccept       = tcpVerdict(io.VerdictAccept)
    tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream)
    tcpVerdictDropStream   = tcpVerdict(io.VerdictDropStream)
)

// tcpContext is the reassembly.AssemblerContext for one packet: it carries
// capture metadata in, and the verdict decided during reassembly back out.
type tcpContext struct {
    *gopacket.PacketMetadata
    Verdict        tcpVerdict
    SrcMAC, DstMAC net.HardwareAddr
}

// GetCaptureInfo satisfies reassembly.AssemblerContext.
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
    return ctx.CaptureInfo
}

// tcpStreamFactory creates tcpStream instances for new connections and holds
// the mutex-guarded ruleset shared by all streams it creates.
type tcpStreamFactory struct {
    WorkerID       int
    Logger         Logger
    Node           *snowflake.Node
    RulesetMutex   sync.RWMutex
    Ruleset        ruleset.Ruleset
    RulesetVersion uint64
}

// New builds the per-connection state: stream identity, the analyzer entries
// selected by the current ruleset, and per-analyzer byte quotas.
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
    id := f.Node.Generate()
    ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
    ctx := ac.(*tcpContext)
    info := ruleset.StreamInfo{
        ID:       id.Int64(),
        Protocol: ruleset.ProtocolTCP,
        SrcMAC:   append(net.HardwareAddr(nil), ctx.SrcMAC...),
        DstMAC:   append(net.HardwareAddr(nil), ctx.DstMAC...),
        SrcIP:    ipSrc,
        DstIP:    ipDst,
        SrcPort:  uint16(tcp.SrcPort),
        DstPort:  uint16(tcp.DstPort),
        Props:    make(analyzer.CombinedPropMap),
    }
    f.Logger.TCPStreamNew(f.WorkerID, info)
    rs, version := f.currentRuleset()
    var ans []analyzer.TCPAnalyzer
    if rs != nil {
        ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
    }
    // Create entries for each analyzer
    entries := make([]*tcpStreamEntry, 0, len(ans))
    for _, a := range ans {
        entries = append(entries, &tcpStreamEntry{
            Name: a.Name(),
            Stream: a.NewTCP(analyzer.TCPInfo{
                SrcIP:   ipSrc,
                DstIP:   ipDst,
                SrcPort: uint16(tcp.SrcPort),
                DstPort: uint16(tcp.DstPort),
            }, &analyzerLogger{
                StreamID: id.Int64(),
                Name:     a.Name(),
                Logger:   f.Logger,
            }),
            HasLimit: a.Limit() > 0,
            Quota:    a.Limit(),
        })
    }
    return &tcpStream{
        info:           info,
        virgin:         true,
        logger:         f.Logger,
        rulesetVersion: version,
        rulesetSource:  f.currentRuleset,
        activeEntries:  entries,
    }
}

// UpdateRuleset swaps the shared ruleset and bumps the version so live
// streams re-match on their next packet.
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
    f.RulesetMutex.Lock()
    defer f.RulesetMutex.Unlock()
    f.Ruleset = r
    f.RulesetVersion++
    return nil
}

// currentRuleset returns the shared ruleset snapshot under the read lock.
func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
    f.RulesetMutex.RLock()
    defer f.RulesetMutex.RUnlock()
    return f.Ruleset, f.RulesetVersion
}

// tcpStream is the per-connection reassembly.Stream implementation.
type tcpStream struct {
    info           ruleset.StreamInfo
    virgin         bool // true if no packets have been processed
    logger         Logger
    rulesetVersion uint64
    rulesetSource  func() (ruleset.Ruleset, uint64)
    activeEntries  []*tcpStreamEntry
    doneEntries    []*tcpStreamEntry
    lastVerdict    tcpVerdict
}

// tcpStreamEntry pairs one analyzer's stream with its remaining byte quota.
type tcpStreamEntry struct {
    Name     string
    Stream   analyzer.TCPStream
    HasLimit bool
    Quota    int
}

// Accept decides whether the assembler should keep feeding this stream; once
// all entries are done and no ruleset change is pending, it short-circuits
// with the settled verdict.
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
    if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
        // Make sure every stream matches against the ruleset at least once,
        // even if there are no activeEntries, as the ruleset may have built-in
        // properties that need to be matched.
        return true
    } else {
        ctx := ac.(*tcpContext)
        ctx.Verdict = s.lastVerdict
        return false
    }
}

// ReassembledSG feeds in-order data to the active analyzers, re-matches the
// ruleset when properties change, and records the resulting verdict in ctx.
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
    dir, start, end, skip := sg.Info()
    rev := dir == reassembly.TCPDirServerToClient
    avail, _ := sg.Lengths()
    data := sg.Fetch(avail)
    updated := false
    for i := len(s.activeEntries) - 1; i >= 0; i-- {
        // Important: reverse order so we can remove entries
        entry := s.activeEntries[i]
        update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
        up1 := processPropUpdate(s.info.Props, entry.Name, update)
        up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
        updated = updated || up1 || up2
        if done {
            s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
            s.doneEntries = append(s.doneEntries, entry)
        }
    }
    ctx := ac.(*tcpContext)
    rs, version := s.currentRuleset()
    rulesetChanged := version != s.rulesetVersion
    s.rulesetVersion = version
    if updated || s.virgin || rulesetChanged {
        s.virgin = false
        s.logger.TCPStreamPropUpdate(s.info, false)
        // Match properties against ruleset
        result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
        if rs != nil {
            result = rs.Match(s.info)
        }
        action := result.Action
        if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
            verdict := actionToTCPVerdict(action)
            s.lastVerdict = verdict
            ctx.Verdict = verdict
            s.logger.TCPStreamAction(s.info, action, false)
            // Verdict issued, no need to process any more packets
            s.closeActiveEntries()
        }
    }
    if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
        // All entries are done but no verdict issued, accept stream
        s.lastVerdict = tcpVerdictAcceptStream
        ctx.Verdict = tcpVerdictAcceptStream
        s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
    }
}

// currentRuleset returns the factory's snapshot, or the captured version when
// no source was wired in.
func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) {
    if s.rulesetSource == nil {
        return nil, s.rulesetVersion
    }
    return s.rulesetSource()
}

// rulesetChanged reports whether the shared ruleset version moved since this
// stream last synchronized with it.
func (s *tcpStream) rulesetChanged() bool {
    _, version := s.currentRuleset()
    return version != s.rulesetVersion
}

// ReassemblyComplete closes remaining analyzers when the connection ends.
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
    s.closeActiveEntries()
    return true
}

// closeActiveEntries signals Close to all active entries, merges their final
// property updates, and retires them.
func (s *tcpStream) closeActiveEntries() {
    // Signal close to all active entries & move them to doneEntries
    updated := false
    for _, entry := range s.activeEntries {
        update := entry.Stream.Close(false)
        up := processPropUpdate(s.info.Props, entry.Name, update)
        updated = updated || up
    }
    if updated {
        s.logger.TCPStreamPropUpdate(s.info, true)
    }
    s.doneEntries = append(s.doneEntries, s.activeEntries...)
    s.activeEntries = nil
}

// feedEntry pushes data into one analyzer entry, enforcing its byte quota;
// exhausting the quota force-closes the entry.
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
    if !entry.HasLimit {
        update, done = entry.Stream.Feed(rev, start, end, skip, data)
    } else {
        qData := data
        if len(qData) > entry.Quota {
            qData = qData[:entry.Quota]
        }
        update, done = entry.Stream.Feed(rev, start, end, skip, qData)
        entry.Quota -= len(qData)
        if entry.Quota <= 0 {
            // Quota exhausted, signal close & move to doneEntries
            closeUpdate = entry.Stream.Close(true)
            done = true
        }
    }
    return
}

// analyzersToTCPAnalyzers keeps only analyzers implementing the TCP interface.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
    tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans))
    for _, a := range ans {
        if tcpM, ok := a.(analyzer.TCPAnalyzer); ok {
            tcpAns = append(tcpAns, tcpM)
        }
    }
    return tcpAns
}

// actionToTCPVerdict maps a ruleset action to a stream-level TCP verdict.
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
    switch a {
    case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify:
        return tcpVerdictAcceptStream
    case ruleset.ActionBlock, ruleset.ActionDrop:
        return tcpVerdictDropStream
    default:
        // Should never happen
        return tcpVerdictAcceptStream
    }
}
+302
View File
@@ -0,0 +1,302 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
)
// tcpFlowMaxBuffer caps how many reassembled bytes per direction are retained
// for analysis; beyond this the flow is passed through (see tcpFlow.feed).
const tcpFlowMaxBuffer = 16384

// tcpFlowDirection identifies which side of the connection sent a segment.
type tcpFlowDirection uint8

const (
    tcpDirC2S tcpFlowDirection = iota // client → server
    tcpDirS2C                         // server → client
)
// tcpFlow tracks one TCP connection: per-direction sequence tracking and
// payload buffers, plus the analyzer entries and ruleset state used to reach
// a verdict.
type tcpFlow struct {
    streamID       uint32
    srcIP          [4]byte
    dstIP          [4]byte
    srcPort        uint16
    dstPort        uint16
    dirSeq         [2]uint32 // next expected sequence number per direction
    dirBuf         [2][]byte // accumulated payload bytes per direction
    info           ruleset.StreamInfo // identity + analyzer props used for ruleset matching
    virgin         bool               // true until the first segment is processed
    logger         Logger
    rulesetVersion uint64
    rulesetSource  func() (ruleset.Ruleset, uint64) // snapshot provider captured at creation
    activeEntries  []*tcpFlowEntry // analyzers still consuming data
    doneEntries    []*tcpFlowEntry // analyzers that finished or hit their quota
    lastVerdict    io.Verdict
    feedCalled     [2]bool // whether any payload has been fed in this direction
}
// tcpFlowEntry pairs one analyzer's TCP stream with its remaining byte quota.
type tcpFlowEntry struct {
    Name     string             // analyzer name, used as the property-map key
    Stream   analyzer.TCPStream // analyzer instance consuming this flow's bytes
    HasLimit bool               // whether Quota is enforced
    Quota    int                // remaining bytes this analyzer may consume
}
// feed processes one TCP segment for this flow and returns the packet's
// verdict. In-order payload is buffered per direction, fed to the active
// analyzer entries, and the ruleset is re-matched whenever analyzer
// properties change.
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
    // First segment (or a ruleset swap) is accepted outright.
    // NOTE(review): this also skips analysis of the first payload — confirm
    // analyzers are expected to start from the second segment.
    if f.rulesetChanged() || f.virgin {
        f.virgin = false
        return io.VerdictAccept
    }
    if len(f.activeEntries) == 0 {
        // All analyzers already finished; keep returning the settled verdict.
        return f.lastVerdict
    }
    dir, rev := f.resolveDirection(tcp)
    if tcp.RST || tcp.FIN {
        // Connection teardown: close analyzers and settle on a stream verdict.
        f.closeActiveEntries()
        f.maybeFinalizeVerdict()
        return f.lastVerdict
    }
    if len(payload) == 0 {
        return io.VerdictAccept
    }
    expected := f.dirSeq[dir]
    if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected {
        // Out-of-order or retransmitted segment: pass through without feeding.
        return io.VerdictAccept
    }
    f.feedCalled[dir] = true
    f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
    f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
    if len(f.dirBuf[dir]) > tcpFlowMaxBuffer {
        // NOTE(review): the append above has already run, so the buffer can
        // keep growing past the cap on later segments — consider truncating
        // or retiring the flow here.
        return io.VerdictAccept
    }
    updated := false
    // Reverse iteration so finished entries can be removed in place.
    for i := len(f.activeEntries) - 1; i >= 0; i-- {
        entry := f.activeEntries[i]
        // NOTE(review): the whole cumulative buffer is re-fed on every
        // segment — verify analyzer.TCPStream.Feed tolerates duplicate bytes.
        update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir])
        u1 := processPropUpdate(f.info.Props, entry.Name, update)
        u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
        updated = updated || u1 || u2
        if done {
            f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
            f.doneEntries = append(f.doneEntries, entry)
        }
    }
    if updated {
        f.logger.TCPStreamPropUpdate(f.info, false)
        rs, version := f.currentRuleset()
        f.rulesetVersion = version
        result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
        if rs != nil {
            result = rs.Match(f.info)
        }
        action := result.Action
        // Maybe/Modify are not final decisions; anything else settles the flow.
        if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
            verdict := actionToTCPVerdict(action)
            f.lastVerdict = verdict
            f.closeActiveEntries()
            f.logger.TCPStreamAction(f.info, action, false)
            return verdict
        }
    }
    f.maybeFinalizeVerdict()
    return f.lastVerdict
}
// maybeFinalizeVerdict upgrades the flow verdict to AcceptStream once every
// analyzer entry has finished without the ruleset issuing a final decision.
func (f *tcpFlow) maybeFinalizeVerdict() {
    if len(f.activeEntries) != 0 || f.lastVerdict != io.VerdictAccept {
        return
    }
    f.lastVerdict = io.VerdictAcceptStream
    f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
}

// resolveDirection classifies a segment as client→server or server→client by
// comparing its source port with the port recorded at flow creation.
func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) {
    if tcp.SrcPort != f.srcPort {
        return uint8(tcpDirS2C), true
    }
    return uint8(tcpDirC2S), false
}
// currentRuleset returns the flow's ruleset snapshot, falling back to the
// version captured at creation when no source function was wired in.
func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
    if src := f.rulesetSource; src != nil {
        return src()
    }
    return nil, f.rulesetVersion
}

// rulesetChanged reports whether the ruleset version moved since this flow
// last synchronized with it.
func (f *tcpFlow) rulesetChanged() bool {
    _, v := f.currentRuleset()
    return v != f.rulesetVersion
}
// closeActiveEntries signals Close to every still-active analyzer entry,
// merges any final property updates, and moves the entries to doneEntries.
func (f *tcpFlow) closeActiveEntries() {
    updated := false
    for _, entry := range f.activeEntries {
        update := entry.Stream.Close(false)
        updated = updated || processPropUpdate(f.info.Props, entry.Name, update)
    }
    if updated {
        f.logger.TCPStreamPropUpdate(f.info, true)
    }
    f.doneEntries = append(f.doneEntries, f.activeEntries...)
    f.activeEntries = nil
}
// tcpFlowManager owns the per-worker table of live TCP flows keyed by the IO
// layer's stream ID.
type tcpFlowManager struct {
    mu            sync.Mutex
    flows         map[uint32]*tcpFlow
    sfNode        *snowflake.Node // generates unique stream IDs for logging/ruleset
    logger        Logger
    rulesetSource func() (ruleset.Ruleset, uint64) // set via updateRuleset; may be nil until then
    workerID      int
    macResolver   *sourceMACResolver
}

// newTCPFlowManager builds an empty flow table for one worker.
// NOTE(review): rulesetSource is left nil here, and worker construction only
// calls updateRuleset when a ruleset is configured — verify createFlow's
// callers tolerate a nil source.
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
    return &tcpFlowManager{
        flows:       make(map[uint32]*tcpFlow),
        sfNode:      node,
        logger:      logger,
        workerID:    workerID,
        macResolver: macResolver,
    }
}
// handle routes one TCP segment to its flow (creating the flow on first
// sight) and evicts the flow once a stream-level verdict is reached or the
// connection is torn down (RST/FIN).
func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict {
    m.mu.Lock()
    flow, ok := m.flows[streamID]
    if !ok {
        flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
        m.flows[streamID] = flow
    }
    m.mu.Unlock()
    // NOTE(review): feed runs outside the lock, so this is only safe if each
    // streamID is always served by a single goroutine — confirm the worker
    // sharding guarantees that.
    verdict := flow.feed(l3, tcp, payload)
    if verdict == io.VerdictAcceptStream || verdict == io.VerdictDropStream || tcp.RST || tcp.FIN {
        m.mu.Lock()
        delete(m.flows, streamID)
        m.mu.Unlock()
    }
    return verdict
}
// createFlow builds the state for a newly observed connection: stream
// identity, resolved MACs, the analyzer entries selected by the current
// ruleset, and the per-direction sequence baseline. Callers must hold m.mu
// (handle does).
//
// Fix: m.rulesetSource can legitimately be nil — newTCPFlowManager leaves it
// unset and updateRuleset is only called when a ruleset is configured — so it
// is now guarded before being invoked instead of panicking on the first
// packet of a ruleset-less worker.
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
    id := m.sfNode.Generate()
    ipSrc := net.IP(l3.SrcIP[:])
    ipDst := net.IP(l3.DstIP[:])
    if len(srcMAC) == 0 && m.macResolver != nil {
        srcMAC = m.macResolver.Resolve(ipSrc)
    }
    info := ruleset.StreamInfo{
        ID:       id.Int64(),
        Protocol: ruleset.ProtocolTCP,
        SrcMAC:   append(net.HardwareAddr(nil), srcMAC...),
        DstMAC:   append(net.HardwareAddr(nil), dstMAC...),
        SrcIP:    ipSrc,
        DstIP:    ipDst,
        SrcPort:  tcp.SrcPort,
        DstPort:  tcp.DstPort,
        Props:    make(analyzer.CombinedPropMap),
    }
    m.logger.TCPStreamNew(m.workerID, info)
    var (
        rs      ruleset.Ruleset
        version uint64
    )
    if m.rulesetSource != nil {
        rs, version = m.rulesetSource()
    }
    var ans []analyzer.TCPAnalyzer
    if rs != nil {
        ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
    }
    entries := make([]*tcpFlowEntry, 0, len(ans))
    for _, a := range ans {
        entries = append(entries, &tcpFlowEntry{
            Name: a.Name(),
            Stream: a.NewTCP(analyzer.TCPInfo{
                SrcIP:   ipSrc,
                DstIP:   ipDst,
                SrcPort: tcp.SrcPort,
                DstPort: tcp.DstPort,
            }, &analyzerLogger{
                StreamID: id.Int64(),
                Name:     a.Name(),
                Logger:   m.logger,
            }),
            HasLimit: a.Limit() > 0,
            Quota:    a.Limit(),
        })
    }
    flow := &tcpFlow{
        streamID:       streamID,
        srcIP:          l3.SrcIP,
        dstIP:          l3.DstIP,
        srcPort:        tcp.SrcPort,
        dstPort:        tcp.DstPort,
        info:           info,
        virgin:         true,
        logger:         m.logger,
        rulesetSource:  m.rulesetSource,
        rulesetVersion: version,
        activeEntries:  entries,
        lastVerdict:    io.VerdictAccept,
    }
    // The SYN consumes one sequence number; data starts at ISN+1.
    flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
    return flow
}
// updateRuleset swaps the ruleset snapshot handed to newly created flows.
//
// Fix: the write to m.rulesetSource previously happened without holding m.mu,
// racing with createFlow, which reads the field inside handle's locked
// section. The swap is now performed under the mutex.
//
// NOTE(review): flows capture the rulesetSource func value at creation time,
// so existing flows keep matching against the old ruleset — confirm whether
// live flows are expected to pick up updates.
func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
    src := func() (ruleset.Ruleset, uint64) {
        return r, version
    }
    m.mu.Lock()
    m.rulesetSource = src
    m.mu.Unlock()
}
// feedFlowEntry pushes data into one analyzer entry, enforcing its byte quota
// when it has one. Exhausting the quota force-closes the analyzer stream and
// reports the entry as done.
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
    if !entry.HasLimit {
        update, done = entry.Stream.Feed(rev, true, false, 0, data)
    } else {
        qData := data
        if len(qData) > entry.Quota {
            // Only feed up to the remaining quota.
            qData = qData[:entry.Quota]
        }
        update, done = entry.Stream.Feed(rev, true, false, 0, qData)
        entry.Quota -= len(qData)
        if entry.Quota <= 0 {
            // Quota exhausted: signal a hard close and retire the entry.
            closeUpdate = entry.Stream.Close(true)
            done = true
        }
    }
    return
}
// analyzersToTCPAnalyzers keeps only the analyzers that implement the
// TCP-specific interface.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
    out := make([]analyzer.TCPAnalyzer, 0, len(ans))
    for _, candidate := range ans {
        tcpA, ok := candidate.(analyzer.TCPAnalyzer)
        if !ok {
            continue
        }
        out = append(out, tcpA)
    }
    return out
}

// actionToTCPVerdict maps a ruleset action onto the stream-level verdict used
// for TCP flows; anything that is not an explicit block/drop accepts the
// stream.
func actionToTCPVerdict(a ruleset.Action) io.Verdict {
    if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
        return io.VerdictDropStream
    }
    // ActionMaybe, ActionAllow, ActionModify and anything unexpected.
    return io.VerdictAcceptStream
}
+106 -90
View File
@@ -10,20 +10,13 @@ import (
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
const (
defaultChanSize = 64
defaultTCPMaxBufferedPagesTotal = 4096
defaultTCPMaxBufferedPagesPerConnection = 64
defaultUDPMaxStreams = 4096
)
var _ Engine = (*engine)(nil)
type workerPacket struct {
StreamID uint32
Data []byte
LayerType gopacket.LayerType
SrcMAC net.HardwareAddr
DstMAC net.HardwareAddr
SetVerdict func(io.Verdict, []byte) error
@@ -35,12 +28,8 @@ type worker struct {
logger Logger
macResolver *sourceMACResolver
tcpStreamFactory *tcpStreamFactory
tcpStreamPool *reassembly.StreamPool
tcpAssembler *reassembly.Assembler
udpStreamFactory *udpStreamFactory
udpStreamManager *udpStreamManager
tcpFlowMgr *tcpFlowManager
udpSM *udpStreamManager
modSerializeBuffer gopacket.SerializeBuffer
}
@@ -51,23 +40,17 @@ type workerConfig struct {
Logger Logger
Ruleset ruleset.Ruleset
MACResolver *sourceMACResolver
TCPMaxBufferedPagesTotal int
TCPMaxBufferedPagesPerConn int
TCPMaxBufferedPagesTotal int // unused, kept for config compat
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
UDPMaxStreams int
}
func (c *workerConfig) fillDefaults() {
if c.ChanSize <= 0 {
c.ChanSize = defaultChanSize
}
if c.TCPMaxBufferedPagesTotal <= 0 {
c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal
}
if c.TCPMaxBufferedPagesPerConn <= 0 {
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
c.ChanSize = 64
}
if c.UDPMaxStreams <= 0 {
c.UDPMaxStreams = defaultUDPMaxStreams
c.UDPMaxStreams = 4096
}
}
@@ -77,16 +60,12 @@ func newWorker(config workerConfig) (*worker, error) {
if err != nil {
return nil, err
}
tcpSF := &tcpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
Node: sfNode,
Ruleset: config.Ruleset,
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
if config.Ruleset != nil {
tcpMgr.updateRuleset(config.Ruleset, 0)
}
tcpStreamPool := reassembly.NewStreamPool(tcpSF)
tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
udpSF := &udpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
@@ -97,25 +76,24 @@ func newWorker(config workerConfig) (*worker, error) {
if err != nil {
return nil, err
}
return &worker{
id: config.ID,
packetChan: make(chan *workerPacket, config.ChanSize),
logger: config.Logger,
macResolver: config.MACResolver,
tcpStreamFactory: tcpSF,
tcpStreamPool: tcpStreamPool,
tcpAssembler: tcpAssembler,
udpStreamFactory: udpSF,
udpStreamManager: udpSM,
tcpFlowMgr: tcpMgr,
udpSM: udpSM,
modSerializeBuffer: gopacket.NewSerializeBuffer(),
}, nil
}
func (w *worker) Feed(p *workerPacket) {
func (w *worker) Feed(p *workerPacket) bool {
select {
case w.packetChan <- p:
return true
default:
_ = p.SetVerdict(io.VerdictAccept, nil)
return false
}
}
@@ -126,78 +104,116 @@ func (w *worker) Run(ctx context.Context) {
select {
case <-ctx.Done():
return
case wPkt := <-w.packetChan:
if wPkt == nil {
case wp := <-w.packetChan:
if wp == nil {
return
}
pkt := gopacket.NewPacket(wPkt.Data, wPkt.LayerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
v, b := w.handle(wPkt.StreamID, pkt, wPkt.SrcMAC, wPkt.DstMAC)
_ = wPkt.SetVerdict(v, b)
v, b := w.handle(wp)
_ = wp.SetVerdict(v, b)
}
}
}
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil {
return err
}
return w.udpStreamFactory.UpdateRuleset(r)
w.tcpFlowMgr.updateRuleset(r, 0)
return w.udpSM.factory.UpdateRuleset(r)
}
func (w *worker) handle(streamID uint32, p gopacket.Packet, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
if netLayer == nil || trLayer == nil {
// Invalid packet
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
data := wp.Data
if len(data) == 0 {
return io.VerdictAccept, nil
}
ipFlow := netLayer.NetworkFlow()
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(net.IP(ipFlow.Src().Raw()))
}
switch tr := trLayer.(type) {
case *layers.TCP:
return w.handleTCP(ipFlow, srcMAC, dstMAC, p.Metadata(), tr), nil
case *layers.UDP:
v, modPayload := w.handleUDP(streamID, ipFlow, srcMAC, dstMAC, tr)
if v == io.VerdictAcceptModify && modPayload != nil {
tr.Payload = modPayload
_ = tr.SetNetworkLayerForChecksum(netLayer)
_ = w.modSerializeBuffer.Clear()
err := gopacket.SerializePacket(w.modSerializeBuffer,
gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, p)
if err != nil {
// Just accept without modification for now
ipVersion := data[0] >> 4
if ipVersion == 4 {
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil
}
return v, w.modSerializeBuffer.Bytes()
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
return w.serializeModifiedUDP(data, l3, udp, transport, modPayload)
}
return v, nil
default:
return io.VerdictAccept, nil
}
return v, nil
default:
// Unsupported protocol
}
// Ethernet frame path (for custom PacketIO)
if ipVersion == 6 {
// TODO: IPv6 support with raw parsing
return io.VerdictAccept, nil
}
return io.VerdictAccept, nil
}
func (w *worker) handleTCP(ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
ctx := &tcpContext{
PacketMetadata: pMeta,
Verdict: tcpVerdictAccept,
SrcMAC: srcMAC,
DstMAC: dstMAC,
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:])
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)})
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc)
}
w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx)
return io.Verdict(ctx.Verdict)
}
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, udp *layers.UDP) (io.Verdict, []byte) {
ctx := &udpContext{
uc := &udpContext{
Verdict: udpVerdictAccept,
SrcMAC: srcMAC,
DstMAC: dstMAC,
}
w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx)
return io.Verdict(ctx.Verdict), ctx.Packet
// Temporarily set payload on a UDP layer so existing UDP handling works
// We pass the payload through the context
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
BaseLayer: layers.BaseLayer{Payload: payload},
SrcPort: layers.UDPPort(udp.SrcPort),
DstPort: layers.UDPPort(udp.DstPort),
}, uc)
return io.Verdict(uc.Verdict), uc.Packet
}
// serializeModifiedUDP rebuilds the full IPv4 packet with modPayload swapped
// in as the UDP payload, recomputing lengths and checksums. On any decode or
// serialization failure the packet is accepted unmodified.
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) {
    ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
    netLayer := ipPkt.NetworkLayer()
    trLayer := ipPkt.TransportLayer()
    if netLayer == nil || trLayer == nil {
        return io.VerdictAccept, nil
    }
    udpLayer, ok := trLayer.(*layers.UDP)
    if !ok {
        return io.VerdictAccept, nil
    }
    udpLayer.Payload = modPayload
    // The UDP checksum covers a pseudo-header derived from the IP layer.
    _ = udpLayer.SetNetworkLayerForChecksum(netLayer)
    _ = w.modSerializeBuffer.Clear()
    err := gopacket.SerializePacket(w.modSerializeBuffer,
        gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
    if err != nil {
        // Serialization failed: fall back to accepting without modification.
        return io.VerdictAccept, nil
    }
    return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
}