refactor: engine/tcp/worker perf improvements

This commit is contained in:
2026-05-12 15:16:11 +00:00
parent dc16b979e7
commit ecc2cde1c2
9 changed files with 743 additions and 546 deletions
+7 -5
View File
@@ -44,11 +44,13 @@ func New(cfg Config, opts Options) (*App, error) {
ownsIO := false
if packetIO == nil {
packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
QueueSize: cfg.IO.QueueSize,
ReadBuffer: cfg.IO.ReadBuffer,
WriteBuffer: cfg.IO.WriteBuffer,
Local: cfg.IO.Local,
RST: cfg.IO.RST,
QueueSize: cfg.IO.QueueSize,
ReadBuffer: cfg.IO.ReadBuffer,
WriteBuffer: cfg.IO.WriteBuffer,
Local: cfg.IO.Local,
RST: cfg.IO.RST,
NumQueues: cfg.IO.NumQueues,
MaxPacketLen: cfg.IO.MaxPacketLen,
})
if err != nil {
return nil, ConfigError{Field: "io", Err: err}
+7 -5
View File
@@ -17,11 +17,13 @@ type Config struct {
// IOConfig configures packet IO.
type IOConfig struct {
QueueSize uint32 `mapstructure:"queueSize" yaml:"queueSize"`
ReadBuffer int `mapstructure:"rcvBuf" yaml:"rcvBuf"`
WriteBuffer int `mapstructure:"sndBuf" yaml:"sndBuf"`
Local bool `mapstructure:"local" yaml:"local"`
RST bool `mapstructure:"rst" yaml:"rst"`
QueueSize uint32 `mapstructure:"queueSize" yaml:"queueSize"`
ReadBuffer int `mapstructure:"rcvBuf" yaml:"rcvBuf"`
WriteBuffer int `mapstructure:"sndBuf" yaml:"sndBuf"`
Local bool `mapstructure:"local" yaml:"local"`
RST bool `mapstructure:"rst" yaml:"rst"`
NumQueues int `mapstructure:"numQueues" yaml:"numQueues"`
MaxPacketLen uint32 `mapstructure:"maxPacketLen" yaml:"maxPacketLen"`
// PacketIO overrides NFQueue creation when set.
// When provided, App.Close will call PacketIO.Close.
+40 -37
View File
@@ -2,16 +2,11 @@ package engine
import (
"context"
"encoding/binary"
"runtime"
"sync"
"sync/atomic"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
var _ Engine = (*engine)(nil)
@@ -27,12 +22,15 @@ type engine struct {
workers []*worker
verdicts sync.Map // streamID(uint32) → verdictEntry
verdictsGen atomic.Int64 // incremented on ruleset update
overflowCh chan *workerPacket
overflowOnce sync.Once
}
func NewEngine(config Config) (Engine, error) {
workerCount := config.Workers
if workerCount <= 0 {
workerCount = runtime.NumCPU()
workerCount = 1
}
macResolver := newSourceMACResolver()
var err error
@@ -53,9 +51,10 @@ func NewEngine(config Config) (Engine, error) {
}
}
e := &engine{
logger: config.Logger,
io: config.IO,
workers: workers,
logger: config.Logger,
io: config.IO,
workers: workers,
overflowCh: make(chan *workerPacket, 1024),
}
return e, nil
}
@@ -75,6 +74,10 @@ func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel()
e.overflowOnce.Do(func() {
go e.drainOverflow(ioCtx)
})
for _, w := range e.workers {
go w.Run(ioCtx)
}
@@ -111,55 +114,55 @@ func (e *engine) dispatch(p io.Packet) bool {
}
data := p.Data()
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
if !ok {
if !validPacket(data) {
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
gen := e.verdictsGen.Load()
index := streamID % uint32(len(e.workers))
e.workers[index].Feed(&workerPacket{
wp := &workerPacket{
StreamID: streamID,
Data: data,
LayerType: layerType,
SrcMAC: srcMAC,
DstMAC: dstMAC,
SetVerdict: func(v io.Verdict, b []byte) error {
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
}
return e.io.SetVerdict(p, v, b)
},
})
}
if !e.workers[index].Feed(wp) {
select {
case e.overflowCh <- wp:
default:
}
}
return true
}
// classifyPacket detects packet framing and returns a gopacket decode layer
// plus best-effort source/destination MAC addresses when available.
func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
func validPacket(data []byte) bool {
if len(data) == 0 {
return 0, nil, nil, false
return false
}
// Fast path for IP packets (NFQUEUE payloads are typically IP-only).
ipVersion := data[0] >> 4
if ipVersion == 4 {
return layers.LayerTypeIPv4, nil, nil, true
if ipVersion == 4 || ipVersion == 6 {
return true
}
if ipVersion == 6 {
return layers.LayerTypeIPv6, nil, nil, true
}
// Ethernet frame path (for custom PacketIO implementations).
if len(data) >= 14 {
etherType := binary.BigEndian.Uint16(data[12:14])
if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) {
return layers.LayerTypeEthernet,
append([]byte(nil), data[6:12]...),
append([]byte(nil), data[:6]...),
true
etherType := uint16(data[12])<<8 | uint16(data[13])
if etherType == 0x0800 || etherType == 0x86DD {
return true
}
}
return false
}
// drainOverflow consumes packets that could not be queued to any worker
// (both the worker channel and the overflow channel were full at dispatch
// time would drop them; packets that made it into overflowCh land here).
// It runs for the lifetime of the engine context.
//
// NOTE(review): overflow packets are fail-open — they receive
// VerdictAccept without any analysis. Confirm this is the intended
// backpressure policy.
func (e *engine) drainOverflow(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case wp := <-e.overflowCh:
			// Best-effort accept; the SetVerdict error is deliberately ignored.
			_ = wp.SetVerdict(io.VerdictAccept, nil)
		}
	}
}
return 0, nil, nil, false
}
+91
View File
@@ -0,0 +1,91 @@
package engine
import "net"
// L3Info summarizes a parsed IPv4 header.
type L3Info struct {
	Version  uint8 // always 4: only IPv4 is parsed by ParseL3
	Protocol uint8 // IP protocol number (6 = TCP, 17 = UDP, ...)
	IHL      uint8 // header length in 32-bit words (>= 5)
	SrcIP    [4]byte
	DstIP    [4]byte
	Length   uint16 // IPv4 total length, clamped to the buffer when malformed
}

// SrcIPAddr returns the source address as a net.IP.
func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) }

// DstIPAddr returns the destination address as a net.IP.
func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) }

// TCPInfo summarizes a parsed TCP header.
type TCPInfo struct {
	SrcPort uint16
	DstPort uint16
	Seq     uint32
	Ack     uint32
	HdrLen  uint8 // header length in bytes (20-60, includes options)
	SYN     bool
	FIN     bool
	RST     bool
	ACK     bool
}

// UDPInfo summarizes a parsed UDP header.
type UDPInfo struct {
	SrcPort uint16
	DstPort uint16
}

// ParseL3 parses an IPv4 header at the start of data and returns the
// header summary plus the transport-layer bytes that follow it.
//
// Only IPv4 is handled; any other version (including IPv6) returns
// ok = false. Non-initial fragments (fragment offset != 0) are rejected
// because their payload does not begin with a transport header. A total
// length that is inconsistent with the buffer (truncated capture or
// malformed header) is clamped to len(data).
func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
	if len(data) < 20 {
		return
	}
	if data[0]>>4 != 4 {
		return
	}
	ihl := data[0] & 0x0F
	hdrLen := int(ihl) * 4
	if ihl < 5 || len(data) < hdrLen {
		return
	}
	// A non-zero fragment offset (low 13 bits of bytes 6-7) marks a
	// continuation fragment: it starts mid-stream, so handing its bytes
	// to ParseTCP/ParseUDP would decode garbage.
	if fragOff := (uint16(data[6])<<8 | uint16(data[7])) & 0x1FFF; fragOff != 0 {
		return
	}
	totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
	if totalLen < hdrLen || totalLen > len(data) {
		// Malformed or truncated total length: fall back to the buffer size.
		totalLen = len(data)
	}
	return L3Info{
		Version:  4,
		Protocol: data[9],
		IHL:      ihl,
		Length:   uint16(totalLen),
		SrcIP:    [4]byte{data[12], data[13], data[14], data[15]},
		DstIP:    [4]byte{data[16], data[17], data[18], data[19]},
	}, data[hdrLen:totalLen], true
}

// ParseTCP parses a TCP header at the start of transport and returns the
// header summary plus the payload bytes following the header and options.
func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
	if len(transport) < 20 {
		return TCPInfo{}, nil, false
	}
	// Data offset is in 32-bit words; compute in int to avoid any
	// narrow-integer arithmetic below.
	dataOff := int(transport[12]>>4) * 4
	if dataOff < 20 || len(transport) < dataOff {
		return TCPInfo{}, nil, false
	}
	flags := transport[13]
	// The payload is everything after the header. (The previous
	// expression, transport[dataOff:dataOff+uint8(payloadLen)], wrapped
	// around in uint8 arithmetic for payloads larger than 255 bytes and
	// silently truncated the returned payload.)
	return TCPInfo{
		SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
		DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
		Seq:     uint32(transport[4])<<24 | uint32(transport[5])<<16 | uint32(transport[6])<<8 | uint32(transport[7]),
		Ack:     uint32(transport[8])<<24 | uint32(transport[9])<<16 | uint32(transport[10])<<8 | uint32(transport[11]),
		HdrLen:  uint8(dataOff),
		SYN:     flags&0x02 != 0,
		FIN:     flags&0x01 != 0,
		RST:     flags&0x04 != 0,
		ACK:     flags&0x10 != 0,
	}, transport[dataOff:], true
}

// ParseUDP parses a UDP header at the start of transport and returns the
// header summary plus the payload bytes.
//
// When the UDP length field is sane (>= 8 and within the buffer) the
// payload is trimmed to it, so link-layer trailer padding is not handed
// to analyzers; otherwise the remainder of the buffer is returned.
func ParseUDP(transport []byte) (UDPInfo, []byte, bool) {
	if len(transport) < 8 {
		return UDPInfo{}, nil, false
	}
	payload := transport[8:]
	if l := int(uint16(transport[4])<<8 | uint16(transport[5])); l >= 8 && l <= len(transport) {
		payload = transport[8:l]
	}
	return UDPInfo{
		SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
		DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
	}, payload, true
}
-262
View File
@@ -1,262 +0,0 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// tcpVerdict is a subset of io.Verdict for TCP streams.
// We don't allow modifying or dropping a single packet
// for TCP streams for now, as it doesn't make much sense.
type tcpVerdict io.Verdict
const (
tcpVerdictAccept = tcpVerdict(io.VerdictAccept)
tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream)
tcpVerdictDropStream = tcpVerdict(io.VerdictDropStream)
)
type tcpContext struct {
*gopacket.PacketMetadata
Verdict tcpVerdict
SrcMAC, DstMAC net.HardwareAddr
}
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
return ctx.CaptureInfo
}
type tcpStreamFactory struct {
WorkerID int
Logger Logger
Node *snowflake.Node
RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset
RulesetVersion uint64
}
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
id := f.Node.Generate()
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
ctx := ac.(*tcpContext)
info := ruleset.StreamInfo{
ID: id.Int64(),
Protocol: ruleset.ProtocolTCP,
SrcMAC: append(net.HardwareAddr(nil), ctx.SrcMAC...),
DstMAC: append(net.HardwareAddr(nil), ctx.DstMAC...),
SrcIP: ipSrc,
DstIP: ipDst,
SrcPort: uint16(tcp.SrcPort),
DstPort: uint16(tcp.DstPort),
Props: make(analyzer.CombinedPropMap),
}
f.Logger.TCPStreamNew(f.WorkerID, info)
rs, version := f.currentRuleset()
var ans []analyzer.TCPAnalyzer
if rs != nil {
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
}
// Create entries for each analyzer
entries := make([]*tcpStreamEntry, 0, len(ans))
for _, a := range ans {
entries = append(entries, &tcpStreamEntry{
Name: a.Name(),
Stream: a.NewTCP(analyzer.TCPInfo{
SrcIP: ipSrc,
DstIP: ipDst,
SrcPort: uint16(tcp.SrcPort),
DstPort: uint16(tcp.DstPort),
}, &analyzerLogger{
StreamID: id.Int64(),
Name: a.Name(),
Logger: f.Logger,
}),
HasLimit: a.Limit() > 0,
Quota: a.Limit(),
})
}
return &tcpStream{
info: info,
virgin: true,
logger: f.Logger,
rulesetVersion: version,
rulesetSource: f.currentRuleset,
activeEntries: entries,
}
}
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
f.RulesetMutex.Lock()
defer f.RulesetMutex.Unlock()
f.Ruleset = r
f.RulesetVersion++
return nil
}
func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
f.RulesetMutex.RLock()
defer f.RulesetMutex.RUnlock()
return f.Ruleset, f.RulesetVersion
}
type tcpStream struct {
info ruleset.StreamInfo
virgin bool // true if no packets have been processed
logger Logger
rulesetVersion uint64
rulesetSource func() (ruleset.Ruleset, uint64)
activeEntries []*tcpStreamEntry
doneEntries []*tcpStreamEntry
lastVerdict tcpVerdict
}
type tcpStreamEntry struct {
Name string
Stream analyzer.TCPStream
HasLimit bool
Quota int
}
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched.
return true
} else {
ctx := ac.(*tcpContext)
ctx.Verdict = s.lastVerdict
return false
}
}
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
dir, start, end, skip := sg.Info()
rev := dir == reassembly.TCPDirServerToClient
avail, _ := sg.Lengths()
data := sg.Fetch(avail)
updated := false
for i := len(s.activeEntries) - 1; i >= 0; i-- {
// Important: reverse order so we can remove entries
entry := s.activeEntries[i]
update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
up1 := processPropUpdate(s.info.Props, entry.Name, update)
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
updated = updated || up1 || up2
if done {
s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
s.doneEntries = append(s.doneEntries, entry)
}
}
ctx := ac.(*tcpContext)
rs, version := s.currentRuleset()
rulesetChanged := version != s.rulesetVersion
s.rulesetVersion = version
if updated || s.virgin || rulesetChanged {
s.virgin = false
s.logger.TCPStreamPropUpdate(s.info, false)
// Match properties against ruleset
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs != nil {
result = rs.Match(s.info)
}
action := result.Action
if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
verdict := actionToTCPVerdict(action)
s.lastVerdict = verdict
ctx.Verdict = verdict
s.logger.TCPStreamAction(s.info, action, false)
// Verdict issued, no need to process any more packets
s.closeActiveEntries()
}
}
if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
// All entries are done but no verdict issued, accept stream
s.lastVerdict = tcpVerdictAcceptStream
ctx.Verdict = tcpVerdictAcceptStream
s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
}
}
func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) {
if s.rulesetSource == nil {
return nil, s.rulesetVersion
}
return s.rulesetSource()
}
func (s *tcpStream) rulesetChanged() bool {
_, version := s.currentRuleset()
return version != s.rulesetVersion
}
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
s.closeActiveEntries()
return true
}
func (s *tcpStream) closeActiveEntries() {
// Signal close to all active entries & move them to doneEntries
updated := false
for _, entry := range s.activeEntries {
update := entry.Stream.Close(false)
up := processPropUpdate(s.info.Props, entry.Name, update)
updated = updated || up
}
if updated {
s.logger.TCPStreamPropUpdate(s.info, true)
}
s.doneEntries = append(s.doneEntries, s.activeEntries...)
s.activeEntries = nil
}
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
if !entry.HasLimit {
update, done = entry.Stream.Feed(rev, start, end, skip, data)
} else {
qData := data
if len(qData) > entry.Quota {
qData = qData[:entry.Quota]
}
update, done = entry.Stream.Feed(rev, start, end, skip, qData)
entry.Quota -= len(qData)
if entry.Quota <= 0 {
// Quota exhausted, signal close & move to doneEntries
closeUpdate = entry.Stream.Close(true)
done = true
}
}
return
}
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans))
for _, a := range ans {
if tcpM, ok := a.(analyzer.TCPAnalyzer); ok {
tcpAns = append(tcpAns, tcpM)
}
}
return tcpAns
}
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
switch a {
case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify:
return tcpVerdictAcceptStream
case ruleset.ActionBlock, ruleset.ActionDrop:
return tcpVerdictDropStream
default:
// Should never happen
return tcpVerdictAcceptStream
}
}
+302
View File
@@ -0,0 +1,302 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
)
// tcpFlowMaxBuffer caps the number of payload bytes buffered per
// direction for a single TCP flow before analysis stops.
const tcpFlowMaxBuffer = 16384

// tcpFlowDirection identifies which side of a flow a segment belongs to.
type tcpFlowDirection uint8

const (
	tcpDirC2S tcpFlowDirection = iota // direction of the packet that created the flow
	tcpDirS2C                         // the reverse direction
)
// tcpFlow holds per-connection analysis state for one TCP stream,
// keyed by the dispatcher's stream ID.
type tcpFlow struct {
	streamID uint32
	srcIP    [4]byte // endpoints of the packet that created the flow
	dstIP    [4]byte
	srcPort  uint16
	dstPort  uint16

	dirSeq [2]uint32 // next expected sequence number, per direction
	dirBuf [2][]byte // in-order payload accumulated so far, per direction

	info           ruleset.StreamInfo // metadata + analyzer properties matched against the ruleset
	virgin         bool               // true until the first feed call
	logger         Logger
	rulesetVersion uint64 // version this flow last synced/matched against
	rulesetSource  func() (ruleset.Ruleset, uint64)
	activeEntries  []*tcpFlowEntry // analyzers still consuming data
	doneEntries    []*tcpFlowEntry // analyzers that have finished or been closed
	lastVerdict    io.Verdict
	feedCalled     [2]bool // whether each direction has been fed at least once
}
// tcpFlowEntry pairs one analyzer's TCP stream with its remaining byte quota.
type tcpFlowEntry struct {
	Name     string
	Stream   analyzer.TCPStream
	HasLimit bool // true when the analyzer declared a positive byte limit
	Quota    int  // bytes this analyzer may still consume (when HasLimit)
}
// feed processes one TCP segment for this flow and returns the verdict
// to apply to the packet.
//
// Reassembly here is deliberately minimal: per direction, strictly
// in-order payload is appended to dirBuf and the accumulated buffer is
// handed to the analyzers. Out-of-order or retransmitted segments are
// accepted without analysis.
//
// NOTE(review): each call re-feeds the entire accumulated dirBuf rather
// than only the new bytes. If analyzer.TCPStream.Feed expects incremental
// data (as the previous reassembly-based implementation supplied),
// analyzers will observe duplicated prefixes — confirm the contract.
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
	if f.rulesetChanged() || f.virgin {
		f.virgin = false
		// Sync to the current version. Without this, a ruleset update
		// left rulesetChanged() permanently true, so every subsequent
		// packet took this branch and the flow bypassed analysis for
		// the rest of its life.
		_, f.rulesetVersion = f.currentRuleset()
		return io.VerdictAccept
	}
	if len(f.activeEntries) == 0 {
		// All analyzers are done; reuse the verdict decided earlier.
		return f.lastVerdict
	}
	dir, rev := f.resolveDirection(tcp)
	if tcp.RST || tcp.FIN {
		// Connection teardown: close analyzers and settle the verdict.
		f.closeActiveEntries()
		f.maybeFinalizeVerdict()
		return f.lastVerdict
	}
	if len(payload) == 0 {
		return io.VerdictAccept
	}
	expected := f.dirSeq[dir]
	if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected {
		// Out-of-order or retransmitted segment: skip analysis.
		return io.VerdictAccept
	}
	if len(f.dirBuf[dir]) >= tcpFlowMaxBuffer {
		// Buffer cap already reached: stop accumulating. (Previously the
		// append still ran on every in-order segment, so a long-lived
		// flow grew its buffers without bound.)
		return io.VerdictAccept
	}
	f.feedCalled[dir] = true
	f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
	f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
	if len(f.dirBuf[dir]) > tcpFlowMaxBuffer {
		// This segment crossed the cap; do not feed past the limit.
		return io.VerdictAccept
	}
	updated := false
	for i := len(f.activeEntries) - 1; i >= 0; i-- {
		// Reverse order so finished entries can be removed in place.
		entry := f.activeEntries[i]
		update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir])
		u1 := processPropUpdate(f.info.Props, entry.Name, update)
		u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
		updated = updated || u1 || u2
		if done {
			f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
			f.doneEntries = append(f.doneEntries, entry)
		}
	}
	if updated {
		f.logger.TCPStreamPropUpdate(f.info, false)
		rs, version := f.currentRuleset()
		f.rulesetVersion = version
		result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
		if rs != nil {
			result = rs.Match(f.info)
		}
		action := result.Action
		if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
			// Final verdict reached: shut down analysis for this flow.
			verdict := actionToTCPVerdict(action)
			f.lastVerdict = verdict
			f.closeActiveEntries()
			f.logger.TCPStreamAction(f.info, action, false)
			return verdict
		}
	}
	f.maybeFinalizeVerdict()
	return f.lastVerdict
}
// maybeFinalizeVerdict upgrades the flow to a stream-wide accept once
// every analyzer has finished without an explicit verdict being issued.
func (f *tcpFlow) maybeFinalizeVerdict() {
	if len(f.activeEntries) != 0 || f.lastVerdict != io.VerdictAccept {
		return
	}
	f.lastVerdict = io.VerdictAcceptStream
	f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
}
// resolveDirection classifies a segment by comparing its source port to
// the port of the packet that created the flow. rev is true for the
// server-to-client direction.
func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) {
	if tcp.SrcPort != f.srcPort {
		return uint8(tcpDirS2C), true
	}
	return uint8(tcpDirC2S), false
}
// currentRuleset returns the ruleset and version this flow should match
// against, falling back to (nil, current version) when no source is set.
func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
	if src := f.rulesetSource; src != nil {
		return src()
	}
	return nil, f.rulesetVersion
}
// rulesetChanged reports whether the ruleset version has moved since this
// flow last synced to it.
func (f *tcpFlow) rulesetChanged() bool {
	_, current := f.currentRuleset()
	return current != f.rulesetVersion
}
// closeActiveEntries signals close to every still-active analyzer entry,
// folds their final property updates into the stream info, and moves all
// entries to doneEntries.
func (f *tcpFlow) closeActiveEntries() {
	updated := false
	for _, entry := range f.activeEntries {
		update := entry.Stream.Close(false)
		// Evaluate processPropUpdate unconditionally: the previous
		// `updated || processPropUpdate(...)` short-circuited once any
		// entry reported an update, silently dropping the final property
		// updates of every later entry.
		up := processPropUpdate(f.info.Props, entry.Name, update)
		updated = updated || up
	}
	if updated {
		f.logger.TCPStreamPropUpdate(f.info, true)
	}
	f.doneEntries = append(f.doneEntries, f.activeEntries...)
	f.activeEntries = nil
}
// tcpFlowManager owns the per-worker table of live TCP flows, keyed by
// stream ID.
type tcpFlowManager struct {
	mu            sync.Mutex // guards flows
	flows         map[uint32]*tcpFlow
	sfNode        *snowflake.Node // generates per-stream snowflake IDs
	logger        Logger
	rulesetSource func() (ruleset.Ruleset, uint64)
	workerID      int
	macResolver   *sourceMACResolver // best-effort source MAC lookup by IP
}
// newTCPFlowManager builds a per-worker TCP flow manager.
//
// rulesetSource is initialized to a stub returning (nil, 0) so that
// createFlow — which calls it unconditionally — is safe even when
// updateRuleset is never invoked. (Previously the field stayed nil in
// that case and the first flow creation panicked; the worker only calls
// updateRuleset when it was configured with a non-nil ruleset.)
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
	return &tcpFlowManager{
		flows:       make(map[uint32]*tcpFlow),
		sfNode:      node,
		logger:      logger,
		workerID:    workerID,
		macResolver: macResolver,
		rulesetSource: func() (ruleset.Ruleset, uint64) {
			return nil, 0
		},
	}
}
// handle routes one TCP segment to its flow (creating the flow on first
// sight) and returns the resulting verdict. Flows that reach a final
// stream verdict, or whose connection is tearing down (RST/FIN), are
// evicted so the table stays bounded.
//
// NOTE(review): flow.feed runs outside m.mu. That is only safe if all
// packets of a given streamID reach a single goroutine — confirm this
// holds for every caller.
func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict {
	m.mu.Lock()
	flow, ok := m.flows[streamID]
	if !ok {
		flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
		m.flows[streamID] = flow
	}
	m.mu.Unlock()
	verdict := flow.feed(l3, tcp, payload)
	if verdict == io.VerdictAcceptStream || verdict == io.VerdictDropStream || tcp.RST || tcp.FIN {
		m.mu.Lock()
		delete(m.flows, streamID)
		m.mu.Unlock()
	}
	return verdict
}
// createFlow allocates state for a newly seen TCP stream: a snowflake ID,
// the stream metadata matched against the ruleset, and one entry per TCP
// analyzer the current ruleset selects for this stream. It is called by
// handle while m.mu is held.
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
	id := m.sfNode.Generate()
	ipSrc := net.IP(l3.SrcIP[:])
	ipDst := net.IP(l3.DstIP[:])
	// Best-effort source MAC lookup when the packet carried no L2 header.
	if len(srcMAC) == 0 && m.macResolver != nil {
		srcMAC = m.macResolver.Resolve(ipSrc)
	}
	info := ruleset.StreamInfo{
		ID:       id.Int64(),
		Protocol: ruleset.ProtocolTCP,
		// MACs are copied so the flow does not alias caller-owned buffers.
		SrcMAC:  append(net.HardwareAddr(nil), srcMAC...),
		DstMAC:  append(net.HardwareAddr(nil), dstMAC...),
		SrcIP:   ipSrc,
		DstIP:   ipDst,
		SrcPort: tcp.SrcPort,
		DstPort: tcp.DstPort,
		Props:   make(analyzer.CombinedPropMap),
	}
	m.logger.TCPStreamNew(m.workerID, info)
	// NOTE(review): panics if m.rulesetSource is nil (updateRuleset never
	// called) — confirm every construction path sets a source.
	rs, version := m.rulesetSource()
	var ans []analyzer.TCPAnalyzer
	if rs != nil {
		ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
	}
	// One entry per selected analyzer, each with its own byte quota.
	entries := make([]*tcpFlowEntry, 0, len(ans))
	for _, a := range ans {
		entries = append(entries, &tcpFlowEntry{
			Name: a.Name(),
			Stream: a.NewTCP(analyzer.TCPInfo{
				SrcIP:   ipSrc,
				DstIP:   ipDst,
				SrcPort: tcp.SrcPort,
				DstPort: tcp.DstPort,
			}, &analyzerLogger{
				StreamID: id.Int64(),
				Name:     a.Name(),
				Logger:   m.logger,
			}),
			HasLimit: a.Limit() > 0,
			Quota:    a.Limit(),
		})
	}
	flow := &tcpFlow{
		streamID:       streamID,
		srcIP:          l3.SrcIP,
		dstIP:          l3.DstIP,
		srcPort:        tcp.SrcPort,
		dstPort:        tcp.DstPort,
		info:           info,
		virgin:         true,
		logger:         m.logger,
		rulesetSource:  m.rulesetSource,
		rulesetVersion: version,
		activeEntries:  entries,
		lastVerdict:    io.VerdictAccept,
	}
	// Assumes the flow is created on the SYN: next expected C2S sequence
	// is Seq+1. NOTE(review): wrong for flows first seen mid-stream, but
	// feed tolerates it because the first segment of each direction
	// bypasses the sequence check (feedCalled is still false).
	flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
	return flow
}
// updateRuleset swaps the ruleset source used for new flows and for
// per-flow rulesetChanged checks.
//
// NOTE(review): the visible caller passes a constant version (0), so
// flows created before an update compare equal versions and never detect
// the change — confirm callers bump the version on each update.
// NOTE(review): the assignment is not guarded by m.mu while createFlow
// reads m.rulesetSource; racy if this runs concurrently with handle.
func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
	m.rulesetSource = func() (ruleset.Ruleset, uint64) {
		return r, version
	}
}
// feedFlowEntry pushes data into one analyzer entry, enforcing the
// entry's byte quota when it has one. When the quota is exhausted the
// stream is force-closed and the entry is reported as done.
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	if !entry.HasLimit {
		update, done = entry.Stream.Feed(rev, true, false, 0, data)
		return
	}
	// Trim the chunk to whatever quota remains before feeding.
	chunk := data
	if entry.Quota < len(chunk) {
		chunk = chunk[:entry.Quota]
	}
	update, done = entry.Stream.Feed(rev, true, false, 0, chunk)
	entry.Quota -= len(chunk)
	if entry.Quota <= 0 {
		closeUpdate = entry.Stream.Close(true)
		done = true
	}
	return
}
// analyzersToTCPAnalyzers filters a generic analyzer list down to those
// implementing analyzer.TCPAnalyzer, preserving order.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
	out := make([]analyzer.TCPAnalyzer, 0, len(ans))
	for _, candidate := range ans {
		tcpAn, isTCP := candidate.(analyzer.TCPAnalyzer)
		if !isTCP {
			continue
		}
		out = append(out, tcpAn)
	}
	return out
}
// actionToTCPVerdict maps a ruleset action to the stream-level verdict
// used for TCP flows.
func actionToTCPVerdict(a ruleset.Action) io.Verdict {
	if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
		return io.VerdictDropStream
	}
	// ActionMaybe, ActionAllow, ActionModify, and any unknown action all
	// resolve to accepting the stream.
	return io.VerdictAcceptStream
}
+106 -90
View File
@@ -10,20 +10,13 @@ import (
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
const (
defaultChanSize = 64
defaultTCPMaxBufferedPagesTotal = 4096
defaultTCPMaxBufferedPagesPerConnection = 64
defaultUDPMaxStreams = 4096
)
var _ Engine = (*engine)(nil)
type workerPacket struct {
StreamID uint32
Data []byte
LayerType gopacket.LayerType
SrcMAC net.HardwareAddr
DstMAC net.HardwareAddr
SetVerdict func(io.Verdict, []byte) error
@@ -35,12 +28,8 @@ type worker struct {
logger Logger
macResolver *sourceMACResolver
tcpStreamFactory *tcpStreamFactory
tcpStreamPool *reassembly.StreamPool
tcpAssembler *reassembly.Assembler
udpStreamFactory *udpStreamFactory
udpStreamManager *udpStreamManager
tcpFlowMgr *tcpFlowManager
udpSM *udpStreamManager
modSerializeBuffer gopacket.SerializeBuffer
}
@@ -51,23 +40,17 @@ type workerConfig struct {
Logger Logger
Ruleset ruleset.Ruleset
MACResolver *sourceMACResolver
TCPMaxBufferedPagesTotal int
TCPMaxBufferedPagesPerConn int
TCPMaxBufferedPagesTotal int // unused, kept for config compat
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
UDPMaxStreams int
}
func (c *workerConfig) fillDefaults() {
if c.ChanSize <= 0 {
c.ChanSize = defaultChanSize
}
if c.TCPMaxBufferedPagesTotal <= 0 {
c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal
}
if c.TCPMaxBufferedPagesPerConn <= 0 {
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
c.ChanSize = 64
}
if c.UDPMaxStreams <= 0 {
c.UDPMaxStreams = defaultUDPMaxStreams
c.UDPMaxStreams = 4096
}
}
@@ -77,16 +60,12 @@ func newWorker(config workerConfig) (*worker, error) {
if err != nil {
return nil, err
}
tcpSF := &tcpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
Node: sfNode,
Ruleset: config.Ruleset,
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
if config.Ruleset != nil {
tcpMgr.updateRuleset(config.Ruleset, 0)
}
tcpStreamPool := reassembly.NewStreamPool(tcpSF)
tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
udpSF := &udpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
@@ -97,25 +76,24 @@ func newWorker(config workerConfig) (*worker, error) {
if err != nil {
return nil, err
}
return &worker{
id: config.ID,
packetChan: make(chan *workerPacket, config.ChanSize),
logger: config.Logger,
macResolver: config.MACResolver,
tcpStreamFactory: tcpSF,
tcpStreamPool: tcpStreamPool,
tcpAssembler: tcpAssembler,
udpStreamFactory: udpSF,
udpStreamManager: udpSM,
tcpFlowMgr: tcpMgr,
udpSM: udpSM,
modSerializeBuffer: gopacket.NewSerializeBuffer(),
}, nil
}
func (w *worker) Feed(p *workerPacket) {
func (w *worker) Feed(p *workerPacket) bool {
select {
case w.packetChan <- p:
return true
default:
_ = p.SetVerdict(io.VerdictAccept, nil)
return false
}
}
@@ -126,78 +104,116 @@ func (w *worker) Run(ctx context.Context) {
select {
case <-ctx.Done():
return
case wPkt := <-w.packetChan:
if wPkt == nil {
case wp := <-w.packetChan:
if wp == nil {
return
}
pkt := gopacket.NewPacket(wPkt.Data, wPkt.LayerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
v, b := w.handle(wPkt.StreamID, pkt, wPkt.SrcMAC, wPkt.DstMAC)
_ = wPkt.SetVerdict(v, b)
v, b := w.handle(wp)
_ = wp.SetVerdict(v, b)
}
}
}
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil {
return err
}
return w.udpStreamFactory.UpdateRuleset(r)
w.tcpFlowMgr.updateRuleset(r, 0)
return w.udpSM.factory.UpdateRuleset(r)
}
func (w *worker) handle(streamID uint32, p gopacket.Packet, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
if netLayer == nil || trLayer == nil {
// Invalid packet
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
data := wp.Data
if len(data) == 0 {
return io.VerdictAccept, nil
}
ipFlow := netLayer.NetworkFlow()
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(net.IP(ipFlow.Src().Raw()))
}
switch tr := trLayer.(type) {
case *layers.TCP:
return w.handleTCP(ipFlow, srcMAC, dstMAC, p.Metadata(), tr), nil
case *layers.UDP:
v, modPayload := w.handleUDP(streamID, ipFlow, srcMAC, dstMAC, tr)
if v == io.VerdictAcceptModify && modPayload != nil {
tr.Payload = modPayload
_ = tr.SetNetworkLayerForChecksum(netLayer)
_ = w.modSerializeBuffer.Clear()
err := gopacket.SerializePacket(w.modSerializeBuffer,
gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, p)
if err != nil {
// Just accept without modification for now
ipVersion := data[0] >> 4
if ipVersion == 4 {
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil
}
return v, w.modSerializeBuffer.Bytes()
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
return w.serializeModifiedUDP(data, l3, udp, transport, modPayload)
}
return v, nil
default:
return io.VerdictAccept, nil
}
return v, nil
default:
// Unsupported protocol
}
// Ethernet frame path (for custom PacketIO)
if ipVersion == 6 {
// TODO: IPv6 support with raw parsing
return io.VerdictAccept, nil
}
return io.VerdictAccept, nil
}
func (w *worker) handleTCP(ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
ctx := &tcpContext{
PacketMetadata: pMeta,
Verdict: tcpVerdictAccept,
SrcMAC: srcMAC,
DstMAC: dstMAC,
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:])
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)})
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc)
}
w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx)
return io.Verdict(ctx.Verdict)
}
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, udp *layers.UDP) (io.Verdict, []byte) {
ctx := &udpContext{
uc := &udpContext{
Verdict: udpVerdictAccept,
SrcMAC: srcMAC,
DstMAC: dstMAC,
}
w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx)
return io.Verdict(ctx.Verdict), ctx.Packet
// Temporarily set payload on a UDP layer so existing UDP handling works
// We pass the payload through the context
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
BaseLayer: layers.BaseLayer{Payload: payload},
SrcPort: layers.UDPPort(udp.SrcPort),
DstPort: layers.UDPPort(udp.DstPort),
}, uc)
return io.Verdict(uc.Verdict), uc.Packet
}
// serializeModifiedUDP re-serializes an IPv4/UDP packet with modPayload
// substituted for the original UDP payload, recomputing lengths and
// checksums. On any parse or serialize failure it falls back to accepting
// the packet unmodified.
//
// NOTE(review): the returned slice aliases w.modSerializeBuffer, which is
// cleared on the next modified packet — the caller must consume or copy
// the bytes before this worker handles another packet.
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) {
	// Re-decode with gopacket to obtain serializable layer objects; the
	// fast-path parser (ParseL3/ParseUDP) cannot re-emit packets.
	ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
	netLayer := ipPkt.NetworkLayer()
	trLayer := ipPkt.TransportLayer()
	if netLayer == nil || trLayer == nil {
		return io.VerdictAccept, nil
	}
	udpLayer, ok := trLayer.(*layers.UDP)
	if !ok {
		return io.VerdictAccept, nil
	}
	udpLayer.Payload = modPayload
	// The UDP checksum covers an IP pseudo-header; link the layers so
	// ComputeChecksums below can produce a valid value.
	_ = udpLayer.SetNetworkLayerForChecksum(netLayer)
	_ = w.modSerializeBuffer.Clear()
	err := gopacket.SerializePacket(w.modSerializeBuffer,
		gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
	if err != nil {
		// Serialization failed: accept without modification.
		return io.VerdictAccept, nil
	}
	return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
}
+157 -127
View File
@@ -18,9 +18,9 @@ import (
)
const (
nfqueueNum = 100
nfqueueMaxPacketLen = 0xFFFF
nfqueueNumStart = 100
nfqueueDefaultQueueSize = 128
nfqueueDefaultMaxLen = 0xFFFF
nfqueueConnMarkAccept = 1001
nfqueueConnMarkDrop = 1002
@@ -29,17 +29,25 @@ const (
nftTable = "mellaris"
)
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
if numQueues < 1 {
numQueues = 1
}
table := &nftTableSpec{
Family: nftFamily,
Table: nftTable,
}
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum))
queueEnd := nfqueueNumStart + numQueues - 1
if numQueues == 1 {
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart))
} else {
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, queueEnd))
}
if local {
table.Chains = []nftChainSpec{
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
@@ -52,7 +60,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
}
for i := range table.Chains {
c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK")
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
@@ -63,10 +71,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
return table, nil
}
func generateIptRules(local, rst bool) ([]iptRule, error) {
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
if numQueues < 1 {
numQueues = 1
}
var chains []string
if local {
chains = []string{"INPUT", "OUTPUT"}
@@ -75,16 +86,19 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
}
rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
}
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}})
if numQueues == 1 {
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}})
} else {
queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}})
}
}
return rules, nil
}
@@ -93,12 +107,12 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
type nfqueuePacketIO struct {
n *nfqueue.Nfqueue
local bool
rst bool
rSet bool // whether the nftables/iptables rules have been set
nqs []*nfqueue.Nfqueue
numQueues int
local bool
rst bool
rSet bool
// iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
@@ -106,21 +120,28 @@ type nfqueuePacketIO struct {
}
type NFQueuePacketIOConfig struct {
QueueSize uint32
ReadBuffer int
WriteBuffer int
Local bool
RST bool
QueueSize uint32
ReadBuffer int
WriteBuffer int
Local bool
RST bool
NumQueues int
MaxPacketLen uint32
}
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
if config.QueueSize == 0 {
config.QueueSize = nfqueueDefaultQueueSize
}
if config.NumQueues <= 0 {
config.NumQueues = 1
}
if config.MaxPacketLen == 0 {
config.MaxPacketLen = nfqueueDefaultMaxLen
}
var ipt4, ipt6 *iptables.IPTables
var err error
if nftCheck() != nil {
// We prefer nftables, but if it's not available, fall back to iptables
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, err
@@ -130,36 +151,50 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
return nil, err
}
}
n, err := nfqueue.Open(&nfqueue.Config{
NfQueue: nfqueueNum,
MaxPacketLen: nfqueueMaxPacketLen,
MaxQueueLen: config.QueueSize,
Copymode: nfqueue.NfQnlCopyPacket,
Flags: nfqueue.NfQaCfgFlagConntrack,
})
if err != nil {
return nil, err
}
if config.ReadBuffer > 0 {
err = n.Con.SetReadBuffer(config.ReadBuffer)
nqs := make([]*nfqueue.Nfqueue, config.NumQueues)
for i := range nqs {
n, err := nfqueue.Open(&nfqueue.Config{
NfQueue: uint16(nfqueueNumStart + i),
MaxPacketLen: config.MaxPacketLen,
MaxQueueLen: config.QueueSize,
Copymode: nfqueue.NfQnlCopyPacket,
Flags: nfqueue.NfQaCfgFlagConntrack,
})
if err != nil {
_ = n.Close()
for j := 0; j < i; j++ {
nqs[j].Close()
}
return nil, err
}
}
if config.WriteBuffer > 0 {
err = n.Con.SetWriteBuffer(config.WriteBuffer)
if err != nil {
_ = n.Close()
return nil, err
if config.ReadBuffer > 0 {
err = n.Con.SetReadBuffer(config.ReadBuffer)
if err != nil {
for j := 0; j <= i; j++ {
nqs[j].Close()
}
return nil, err
}
}
if config.WriteBuffer > 0 {
err = n.Con.SetWriteBuffer(config.WriteBuffer)
if err != nil {
for j := 0; j <= i; j++ {
nqs[j].Close()
}
return nil, err
}
}
nqs[i] = n
}
return &nfqueuePacketIO{
n: n,
local: config.Local,
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
nqs: nqs,
numQueues: config.NumQueues,
local: config.Local,
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
@@ -175,60 +210,63 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
}, nil
}
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
err := n.n.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int {
if ok, verdict := n.packetAttributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = n.n.SetVerdict(*a.PacketID, verdict)
}
return 0
}
p := &nfqueuePacket{
id: *a.PacketID,
streamID: ctIDFromCtBytes(*a.Ct),
data: *a.Payload,
}
return okBoolToInt(cb(p, nil))
},
func(e error) int {
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
if errors.Is(opErr.Err, unix.ENOBUFS) {
// Kernel buffer temporarily full, ignore
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
for i, nq := range nio.nqs {
nq := nq
err := nq.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int {
if ok, verdict := nio.packetAttributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = nq.SetVerdict(*a.PacketID, verdict)
}
return 0
}
}
return okBoolToInt(cb(nil, e))
})
if err != nil {
return err
}
if !n.rSet {
if n.ipt4 != nil {
err = n.setupIpt(n.local, n.rst, false)
} else {
err = n.setupNft(n.local, n.rst, false)
}
p := &nfqueuePacket{
id: *a.PacketID,
streamID: ctIDFromCtBytes(*a.Ct),
data: *a.Payload,
nq: nq,
}
return okBoolToInt(cb(p, nil))
},
func(e error) int {
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
if errors.Is(opErr.Err, unix.ENOBUFS) {
return 0
}
}
return okBoolToInt(cb(nil, e))
})
if err != nil {
return err
}
n.rSet = true
}
if !nio.rSet {
if nio.ipt4 != nil {
err := nio.setupIpt(nio.local, nio.rst, false)
if err != nil {
return err
}
} else {
err := nio.setupNft(nio.local, nio.rst, false)
if err != nil {
return err
}
}
nio.rSet = true
}
return nil
}
func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
if a.PacketID == nil {
// Re-inject to NFQUEUE is actually not possible in this condition
return false, -1
}
if a.Payload == nil || len(*a.Payload) < 20 {
// 20 is the minimum possible size of an IP packet
return false, nfqueue.NfDrop
}
if a.Ct == nil {
// Multicast packets may not have a conntrack, but only appear in local mode
if n.local {
if nio.local {
return false, nfqueue.NfAccept
}
return false, nfqueue.NfDrop
@@ -236,46 +274,54 @@ func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bo
return true, -1
}
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
nP, ok := p.(*nfqueuePacket)
if !ok {
return &ErrInvalidPacket{Err: errNotNFQueuePacket}
}
switch v {
case VerdictAccept:
return n.n.SetVerdict(nP.id, nfqueue.NfAccept)
return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept)
case VerdictAcceptModify:
return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
case VerdictAcceptStream:
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
case VerdictDrop:
return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop)
case VerdictDropStream:
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
default:
// Invalid verdict, ignore for now
return nil
}
}
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return nio.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error {
if n.rSet {
if n.ipt4 != nil {
_ = n.setupIpt(n.local, n.rst, true)
func (nio *nfqueuePacketIO) Close() error {
if nio.rSet {
if nio.ipt4 != nil {
_ = nio.setupIpt(nio.local, nio.rst, true)
} else {
_ = n.setupNft(n.local, n.rst, true)
_ = nio.setupNft(nio.local, nio.rst, true)
}
n.rSet = false
nio.rSet = false
}
return n.n.Close()
var errs []error
for _, nq := range nio.nqs {
if err := nq.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
}
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
rules, err := generateNftRules(local, rst)
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
rules, err := generateNftRules(local, rst, nio.numQueues)
if err != nil {
return err
}
@@ -283,30 +329,23 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
if remove {
err = nftDelete(nftFamily, nftTable)
} else {
// Delete first to make sure no leftover rules
_ = nftDelete(nftFamily, nftTable)
err = nftAdd(rulesText)
}
if err != nil {
return err
}
return nil
return err
}
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
rules, err := generateIptRules(local, rst)
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
rules, err := generateIptRules(local, rst, nio.numQueues)
if err != nil {
return err
}
if remove {
err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
} else {
err = iptsBatchAppendUnique([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
}
if err != nil {
return err
}
return nil
return err
}
var _ Packet = (*nfqueuePacket)(nil)
@@ -315,30 +354,22 @@ type nfqueuePacket struct {
id uint32
streamID uint32
data []byte
nq *nfqueue.Nfqueue
}
// StreamID returns the identifier that groups this packet with its stream;
// it is derived from the packet's conntrack attribute (see ctIDFromCtBytes).
func (p *nfqueuePacket) StreamID() uint32 {
	return p.streamID
}

// Data returns the raw packet payload as received from the queue.
func (p *nfqueuePacket) Data() []byte {
	return p.data
}
// okBoolToInt translates a callback's boolean result into the int code
// expected by the nfqueue callback API: 0 when ok is true, 1 otherwise.
func okBoolToInt(ok bool) int {
	if !ok {
		return 1
	}
	return 0
}
// nftCheck reports whether the nft command-line tool is available on PATH.
// A non-nil error means nftables cannot be used; NewNFQueuePacketIO then
// falls back to iptables.
func nftCheck() error {
	if _, err := exec.LookPath("nft"); err != nil {
		return err
	}
	return nil
}
func nftAdd(input string) error {
@@ -363,7 +394,6 @@ func (t *nftTableSpec) String() string {
for _, c := range t.Chains {
chains = append(chains, c.String())
}
return fmt.Sprintf(`
%s
+33 -20
View File
@@ -7,6 +7,7 @@ import (
"os"
"reflect"
"strings"
"sync"
"time"
"github.com/expr-lang/expr/builtin"
@@ -67,6 +68,19 @@ type compiledExprRule struct {
var _ Ruleset = (*exprRuleset)(nil)
var (
envPool = sync.Pool{
New: func() any {
return make(map[string]any, 16)
},
}
subMapPool = sync.Pool{
New: func() any {
return make(map[string]any, 8)
},
}
)
type exprRuleset struct {
Rules []compiledExprRule
Ans []analyzer.Analyzer
@@ -79,7 +93,9 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
}
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
env := streamInfoToExprEnv(info)
env := envPool.Get().(map[string]any)
clear(env)
populateExprEnv(env, info)
now := time.Now()
for _, rule := range r.Rules {
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
@@ -99,6 +115,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
r.Logger.Log(logInfo, rule.Name)
}
if rule.Action != nil {
envPool.Put(env)
return MatchResult{
Action: *rule.Action,
ModInstance: rule.ModInstance,
@@ -106,7 +123,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
}
}
}
// No match
envPool.Put(env)
return MatchResult{
Action: ActionMaybe,
}
@@ -228,30 +245,26 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
}, nil
}
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
m := map[string]interface{}{
"id": info.ID,
"proto": info.Protocol.String(),
"mac": map[string]string{
"src": info.SrcMAC.String(),
"dst": info.DstMAC.String(),
},
"ip": map[string]string{
"src": info.SrcIP.String(),
"dst": info.DstIP.String(),
},
"port": map[string]uint16{
"src": info.SrcPort,
"dst": info.DstPort,
},
func populateExprEnv(m map[string]any, info StreamInfo) {
m["id"] = info.ID
m["proto"] = info.Protocol.String()
m["mac"] = map[string]string{
"src": info.SrcMAC.String(),
"dst": info.DstMAC.String(),
}
m["ip"] = map[string]string{
"src": info.SrcIP.String(),
"dst": info.DstIP.String(),
}
m["port"] = map[string]uint16{
"src": info.SrcPort,
"dst": info.DstPort,
}
for anName, anProps := range info.Props {
if len(anProps) != 0 {
// Ignore analyzers with empty properties
m[anName] = anProps
}
}
return m
}
func isBuiltInAnalyzer(name string) bool {