refactor: engine/tcp/worker perf improvements
This commit is contained in:
@@ -49,6 +49,8 @@ func New(cfg Config, opts Options) (*App, error) {
|
|||||||
WriteBuffer: cfg.IO.WriteBuffer,
|
WriteBuffer: cfg.IO.WriteBuffer,
|
||||||
Local: cfg.IO.Local,
|
Local: cfg.IO.Local,
|
||||||
RST: cfg.IO.RST,
|
RST: cfg.IO.RST,
|
||||||
|
NumQueues: cfg.IO.NumQueues,
|
||||||
|
MaxPacketLen: cfg.IO.MaxPacketLen,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ConfigError{Field: "io", Err: err}
|
return nil, ConfigError{Field: "io", Err: err}
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ type IOConfig struct {
|
|||||||
WriteBuffer int `mapstructure:"sndBuf" yaml:"sndBuf"`
|
WriteBuffer int `mapstructure:"sndBuf" yaml:"sndBuf"`
|
||||||
Local bool `mapstructure:"local" yaml:"local"`
|
Local bool `mapstructure:"local" yaml:"local"`
|
||||||
RST bool `mapstructure:"rst" yaml:"rst"`
|
RST bool `mapstructure:"rst" yaml:"rst"`
|
||||||
|
NumQueues int `mapstructure:"numQueues" yaml:"numQueues"`
|
||||||
|
MaxPacketLen uint32 `mapstructure:"maxPacketLen" yaml:"maxPacketLen"`
|
||||||
|
|
||||||
// PacketIO overrides NFQueue creation when set.
|
// PacketIO overrides NFQueue creation when set.
|
||||||
// When provided, App.Close will call PacketIO.Close.
|
// When provided, App.Close will call PacketIO.Close.
|
||||||
|
|||||||
+37
-34
@@ -2,16 +2,11 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Engine = (*engine)(nil)
|
var _ Engine = (*engine)(nil)
|
||||||
@@ -27,12 +22,15 @@ type engine struct {
|
|||||||
workers []*worker
|
workers []*worker
|
||||||
verdicts sync.Map // streamID(uint32) → verdictEntry
|
verdicts sync.Map // streamID(uint32) → verdictEntry
|
||||||
verdictsGen atomic.Int64 // incremented on ruleset update
|
verdictsGen atomic.Int64 // incremented on ruleset update
|
||||||
|
|
||||||
|
overflowCh chan *workerPacket
|
||||||
|
overflowOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEngine(config Config) (Engine, error) {
|
func NewEngine(config Config) (Engine, error) {
|
||||||
workerCount := config.Workers
|
workerCount := config.Workers
|
||||||
if workerCount <= 0 {
|
if workerCount <= 0 {
|
||||||
workerCount = runtime.NumCPU()
|
workerCount = 1
|
||||||
}
|
}
|
||||||
macResolver := newSourceMACResolver()
|
macResolver := newSourceMACResolver()
|
||||||
var err error
|
var err error
|
||||||
@@ -56,6 +54,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
io: config.IO,
|
io: config.IO,
|
||||||
workers: workers,
|
workers: workers,
|
||||||
|
overflowCh: make(chan *workerPacket, 1024),
|
||||||
}
|
}
|
||||||
return e, nil
|
return e, nil
|
||||||
}
|
}
|
||||||
@@ -75,6 +74,10 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||||
defer ioCancel()
|
defer ioCancel()
|
||||||
|
|
||||||
|
e.overflowOnce.Do(func() {
|
||||||
|
go e.drainOverflow(ioCtx)
|
||||||
|
})
|
||||||
|
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
@@ -111,55 +114,55 @@ func (e *engine) dispatch(p io.Packet) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := p.Data()
|
data := p.Data()
|
||||||
layerType, srcMAC, dstMAC, ok := classifyPacket(data)
|
if !validPacket(data) {
|
||||||
if !ok {
|
|
||||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
gen := e.verdictsGen.Load()
|
gen := e.verdictsGen.Load()
|
||||||
index := streamID % uint32(len(e.workers))
|
index := streamID % uint32(len(e.workers))
|
||||||
e.workers[index].Feed(&workerPacket{
|
wp := &workerPacket{
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
Data: data,
|
Data: data,
|
||||||
LayerType: layerType,
|
|
||||||
SrcMAC: srcMAC,
|
|
||||||
DstMAC: dstMAC,
|
|
||||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||||
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
|
||||||
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
|
||||||
}
|
}
|
||||||
return e.io.SetVerdict(p, v, b)
|
return e.io.SetVerdict(p, v, b)
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
if !e.workers[index].Feed(wp) {
|
||||||
|
select {
|
||||||
|
case e.overflowCh <- wp:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// classifyPacket detects packet framing and returns a gopacket decode layer
|
func validPacket(data []byte) bool {
|
||||||
// plus best-effort source/destination MAC addresses when available.
|
|
||||||
func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
|
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return 0, nil, nil, false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fast path for IP packets (NFQUEUE payloads are typically IP-only).
|
|
||||||
ipVersion := data[0] >> 4
|
ipVersion := data[0] >> 4
|
||||||
if ipVersion == 4 {
|
if ipVersion == 4 || ipVersion == 6 {
|
||||||
return layers.LayerTypeIPv4, nil, nil, true
|
return true
|
||||||
}
|
}
|
||||||
if ipVersion == 6 {
|
|
||||||
return layers.LayerTypeIPv6, nil, nil, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ethernet frame path (for custom PacketIO implementations).
|
|
||||||
if len(data) >= 14 {
|
if len(data) >= 14 {
|
||||||
etherType := binary.BigEndian.Uint16(data[12:14])
|
etherType := uint16(data[12])<<8 | uint16(data[13])
|
||||||
if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) {
|
if etherType == 0x0800 || etherType == 0x86DD {
|
||||||
return layers.LayerTypeEthernet,
|
return true
|
||||||
append([]byte(nil), data[6:12]...),
|
}
|
||||||
append([]byte(nil), data[:6]...),
|
}
|
||||||
true
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *engine) drainOverflow(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case wp := <-e.overflowCh:
|
||||||
|
_ = wp.SetVerdict(io.VerdictAccept, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, nil, nil, false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type L3Info struct {
|
||||||
|
Version uint8
|
||||||
|
Protocol uint8
|
||||||
|
IHL uint8
|
||||||
|
SrcIP [4]byte
|
||||||
|
DstIP [4]byte
|
||||||
|
Length uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) }
|
||||||
|
func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) }
|
||||||
|
|
||||||
|
type TCPInfo struct {
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
Seq uint32
|
||||||
|
Ack uint32
|
||||||
|
HdrLen uint8
|
||||||
|
SYN bool
|
||||||
|
FIN bool
|
||||||
|
RST bool
|
||||||
|
ACK bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type UDPInfo struct {
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
|
||||||
|
if len(data) < 20 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
version := data[0] >> 4
|
||||||
|
if version != 4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ihl := data[0] & 0x0F
|
||||||
|
if ihl < 5 || len(data) < int(ihl)*4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
|
||||||
|
if totalLen < int(ihl)*4 || totalLen > len(data) {
|
||||||
|
totalLen = len(data)
|
||||||
|
}
|
||||||
|
return L3Info{
|
||||||
|
Version: 4,
|
||||||
|
Protocol: data[9],
|
||||||
|
IHL: ihl,
|
||||||
|
Length: uint16(totalLen),
|
||||||
|
SrcIP: [4]byte{data[12], data[13], data[14], data[15]},
|
||||||
|
DstIP: [4]byte{data[16], data[17], data[18], data[19]},
|
||||||
|
}, data[ihl*4:totalLen], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
|
||||||
|
if len(transport) < 20 {
|
||||||
|
return TCPInfo{}, nil, false
|
||||||
|
}
|
||||||
|
dataOff := uint8(transport[12]>>4) * 4
|
||||||
|
if dataOff < 20 || len(transport) < int(dataOff) {
|
||||||
|
return TCPInfo{}, nil, false
|
||||||
|
}
|
||||||
|
flags := transport[13]
|
||||||
|
payloadLen := len(transport) - int(dataOff)
|
||||||
|
return TCPInfo{
|
||||||
|
SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
|
||||||
|
DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
|
||||||
|
Seq: uint32(transport[4])<<24 | uint32(transport[5])<<16 | uint32(transport[6])<<8 | uint32(transport[7]),
|
||||||
|
Ack: uint32(transport[8])<<24 | uint32(transport[9])<<16 | uint32(transport[10])<<8 | uint32(transport[11]),
|
||||||
|
HdrLen: dataOff,
|
||||||
|
SYN: flags&0x02 != 0,
|
||||||
|
FIN: flags&0x01 != 0,
|
||||||
|
RST: flags&0x04 != 0,
|
||||||
|
ACK: flags&0x10 != 0,
|
||||||
|
}, transport[dataOff : dataOff+uint8(payloadLen)], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseUDP(transport []byte) (UDPInfo, []byte, bool) {
|
||||||
|
if len(transport) < 8 {
|
||||||
|
return UDPInfo{}, nil, false
|
||||||
|
}
|
||||||
|
return UDPInfo{
|
||||||
|
SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
|
||||||
|
DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
|
||||||
|
}, transport[8:], true
|
||||||
|
}
|
||||||
-262
@@ -1,262 +0,0 @@
|
|||||||
package engine
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/google/gopacket/reassembly"
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpVerdict is a subset of io.Verdict for TCP streams.
|
|
||||||
// We don't allow modifying or dropping a single packet
|
|
||||||
// for TCP streams for now, as it doesn't make much sense.
|
|
||||||
type tcpVerdict io.Verdict
|
|
||||||
|
|
||||||
const (
|
|
||||||
tcpVerdictAccept = tcpVerdict(io.VerdictAccept)
|
|
||||||
tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream)
|
|
||||||
tcpVerdictDropStream = tcpVerdict(io.VerdictDropStream)
|
|
||||||
)
|
|
||||||
|
|
||||||
type tcpContext struct {
|
|
||||||
*gopacket.PacketMetadata
|
|
||||||
Verdict tcpVerdict
|
|
||||||
SrcMAC, DstMAC net.HardwareAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
|
|
||||||
return ctx.CaptureInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpStreamFactory struct {
|
|
||||||
WorkerID int
|
|
||||||
Logger Logger
|
|
||||||
Node *snowflake.Node
|
|
||||||
|
|
||||||
RulesetMutex sync.RWMutex
|
|
||||||
Ruleset ruleset.Ruleset
|
|
||||||
RulesetVersion uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
|
|
||||||
id := f.Node.Generate()
|
|
||||||
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
|
|
||||||
ctx := ac.(*tcpContext)
|
|
||||||
info := ruleset.StreamInfo{
|
|
||||||
ID: id.Int64(),
|
|
||||||
Protocol: ruleset.ProtocolTCP,
|
|
||||||
SrcMAC: append(net.HardwareAddr(nil), ctx.SrcMAC...),
|
|
||||||
DstMAC: append(net.HardwareAddr(nil), ctx.DstMAC...),
|
|
||||||
SrcIP: ipSrc,
|
|
||||||
DstIP: ipDst,
|
|
||||||
SrcPort: uint16(tcp.SrcPort),
|
|
||||||
DstPort: uint16(tcp.DstPort),
|
|
||||||
Props: make(analyzer.CombinedPropMap),
|
|
||||||
}
|
|
||||||
f.Logger.TCPStreamNew(f.WorkerID, info)
|
|
||||||
rs, version := f.currentRuleset()
|
|
||||||
var ans []analyzer.TCPAnalyzer
|
|
||||||
if rs != nil {
|
|
||||||
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
|
|
||||||
}
|
|
||||||
// Create entries for each analyzer
|
|
||||||
entries := make([]*tcpStreamEntry, 0, len(ans))
|
|
||||||
for _, a := range ans {
|
|
||||||
entries = append(entries, &tcpStreamEntry{
|
|
||||||
Name: a.Name(),
|
|
||||||
Stream: a.NewTCP(analyzer.TCPInfo{
|
|
||||||
SrcIP: ipSrc,
|
|
||||||
DstIP: ipDst,
|
|
||||||
SrcPort: uint16(tcp.SrcPort),
|
|
||||||
DstPort: uint16(tcp.DstPort),
|
|
||||||
}, &analyzerLogger{
|
|
||||||
StreamID: id.Int64(),
|
|
||||||
Name: a.Name(),
|
|
||||||
Logger: f.Logger,
|
|
||||||
}),
|
|
||||||
HasLimit: a.Limit() > 0,
|
|
||||||
Quota: a.Limit(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &tcpStream{
|
|
||||||
info: info,
|
|
||||||
virgin: true,
|
|
||||||
logger: f.Logger,
|
|
||||||
rulesetVersion: version,
|
|
||||||
rulesetSource: f.currentRuleset,
|
|
||||||
activeEntries: entries,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
|
|
||||||
f.RulesetMutex.Lock()
|
|
||||||
defer f.RulesetMutex.Unlock()
|
|
||||||
f.Ruleset = r
|
|
||||||
f.RulesetVersion++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
|
||||||
f.RulesetMutex.RLock()
|
|
||||||
defer f.RulesetMutex.RUnlock()
|
|
||||||
return f.Ruleset, f.RulesetVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpStream struct {
|
|
||||||
info ruleset.StreamInfo
|
|
||||||
virgin bool // true if no packets have been processed
|
|
||||||
logger Logger
|
|
||||||
rulesetVersion uint64
|
|
||||||
rulesetSource func() (ruleset.Ruleset, uint64)
|
|
||||||
activeEntries []*tcpStreamEntry
|
|
||||||
doneEntries []*tcpStreamEntry
|
|
||||||
lastVerdict tcpVerdict
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpStreamEntry struct {
|
|
||||||
Name string
|
|
||||||
Stream analyzer.TCPStream
|
|
||||||
HasLimit bool
|
|
||||||
Quota int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
|
|
||||||
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
|
|
||||||
// Make sure every stream matches against the ruleset at least once,
|
|
||||||
// even if there are no activeEntries, as the ruleset may have built-in
|
|
||||||
// properties that need to be matched.
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
ctx := ac.(*tcpContext)
|
|
||||||
ctx.Verdict = s.lastVerdict
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
|
|
||||||
dir, start, end, skip := sg.Info()
|
|
||||||
rev := dir == reassembly.TCPDirServerToClient
|
|
||||||
avail, _ := sg.Lengths()
|
|
||||||
data := sg.Fetch(avail)
|
|
||||||
updated := false
|
|
||||||
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
|
||||||
// Important: reverse order so we can remove entries
|
|
||||||
entry := s.activeEntries[i]
|
|
||||||
update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
|
|
||||||
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
|
||||||
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
|
||||||
updated = updated || up1 || up2
|
|
||||||
if done {
|
|
||||||
s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
|
|
||||||
s.doneEntries = append(s.doneEntries, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx := ac.(*tcpContext)
|
|
||||||
rs, version := s.currentRuleset()
|
|
||||||
rulesetChanged := version != s.rulesetVersion
|
|
||||||
s.rulesetVersion = version
|
|
||||||
if updated || s.virgin || rulesetChanged {
|
|
||||||
s.virgin = false
|
|
||||||
s.logger.TCPStreamPropUpdate(s.info, false)
|
|
||||||
// Match properties against ruleset
|
|
||||||
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
|
||||||
if rs != nil {
|
|
||||||
result = rs.Match(s.info)
|
|
||||||
}
|
|
||||||
action := result.Action
|
|
||||||
if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
|
|
||||||
verdict := actionToTCPVerdict(action)
|
|
||||||
s.lastVerdict = verdict
|
|
||||||
ctx.Verdict = verdict
|
|
||||||
s.logger.TCPStreamAction(s.info, action, false)
|
|
||||||
// Verdict issued, no need to process any more packets
|
|
||||||
s.closeActiveEntries()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
|
|
||||||
// All entries are done but no verdict issued, accept stream
|
|
||||||
s.lastVerdict = tcpVerdictAcceptStream
|
|
||||||
ctx.Verdict = tcpVerdictAcceptStream
|
|
||||||
s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) {
|
|
||||||
if s.rulesetSource == nil {
|
|
||||||
return nil, s.rulesetVersion
|
|
||||||
}
|
|
||||||
return s.rulesetSource()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) rulesetChanged() bool {
|
|
||||||
_, version := s.currentRuleset()
|
|
||||||
return version != s.rulesetVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
|
|
||||||
s.closeActiveEntries()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) closeActiveEntries() {
|
|
||||||
// Signal close to all active entries & move them to doneEntries
|
|
||||||
updated := false
|
|
||||||
for _, entry := range s.activeEntries {
|
|
||||||
update := entry.Stream.Close(false)
|
|
||||||
up := processPropUpdate(s.info.Props, entry.Name, update)
|
|
||||||
updated = updated || up
|
|
||||||
}
|
|
||||||
if updated {
|
|
||||||
s.logger.TCPStreamPropUpdate(s.info, true)
|
|
||||||
}
|
|
||||||
s.doneEntries = append(s.doneEntries, s.activeEntries...)
|
|
||||||
s.activeEntries = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
|
||||||
if !entry.HasLimit {
|
|
||||||
update, done = entry.Stream.Feed(rev, start, end, skip, data)
|
|
||||||
} else {
|
|
||||||
qData := data
|
|
||||||
if len(qData) > entry.Quota {
|
|
||||||
qData = qData[:entry.Quota]
|
|
||||||
}
|
|
||||||
update, done = entry.Stream.Feed(rev, start, end, skip, qData)
|
|
||||||
entry.Quota -= len(qData)
|
|
||||||
if entry.Quota <= 0 {
|
|
||||||
// Quota exhausted, signal close & move to doneEntries
|
|
||||||
closeUpdate = entry.Stream.Close(true)
|
|
||||||
done = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
|
|
||||||
tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans))
|
|
||||||
for _, a := range ans {
|
|
||||||
if tcpM, ok := a.(analyzer.TCPAnalyzer); ok {
|
|
||||||
tcpAns = append(tcpAns, tcpM)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tcpAns
|
|
||||||
}
|
|
||||||
|
|
||||||
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
|
|
||||||
switch a {
|
|
||||||
case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify:
|
|
||||||
return tcpVerdictAcceptStream
|
|
||||||
case ruleset.ActionBlock, ruleset.ActionDrop:
|
|
||||||
return tcpVerdictDropStream
|
|
||||||
default:
|
|
||||||
// Should never happen
|
|
||||||
return tcpVerdictAcceptStream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,302 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
|
"github.com/bwmarrin/snowflake"
|
||||||
|
)
|
||||||
|
|
||||||
|
const tcpFlowMaxBuffer = 16384
|
||||||
|
|
||||||
|
type tcpFlowDirection uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpDirC2S tcpFlowDirection = iota
|
||||||
|
tcpDirS2C
|
||||||
|
)
|
||||||
|
|
||||||
|
type tcpFlow struct {
|
||||||
|
streamID uint32
|
||||||
|
srcIP [4]byte
|
||||||
|
dstIP [4]byte
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
|
||||||
|
dirSeq [2]uint32
|
||||||
|
dirBuf [2][]byte
|
||||||
|
|
||||||
|
info ruleset.StreamInfo
|
||||||
|
virgin bool
|
||||||
|
logger Logger
|
||||||
|
rulesetVersion uint64
|
||||||
|
rulesetSource func() (ruleset.Ruleset, uint64)
|
||||||
|
activeEntries []*tcpFlowEntry
|
||||||
|
doneEntries []*tcpFlowEntry
|
||||||
|
lastVerdict io.Verdict
|
||||||
|
feedCalled [2]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpFlowEntry struct {
|
||||||
|
Name string
|
||||||
|
Stream analyzer.TCPStream
|
||||||
|
HasLimit bool
|
||||||
|
Quota int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
||||||
|
if f.rulesetChanged() || f.virgin {
|
||||||
|
f.virgin = false
|
||||||
|
return io.VerdictAccept
|
||||||
|
}
|
||||||
|
if len(f.activeEntries) == 0 {
|
||||||
|
return f.lastVerdict
|
||||||
|
}
|
||||||
|
|
||||||
|
dir, rev := f.resolveDirection(tcp)
|
||||||
|
|
||||||
|
if tcp.RST || tcp.FIN {
|
||||||
|
f.closeActiveEntries()
|
||||||
|
f.maybeFinalizeVerdict()
|
||||||
|
return f.lastVerdict
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return io.VerdictAccept
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := f.dirSeq[dir]
|
||||||
|
if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected {
|
||||||
|
return io.VerdictAccept
|
||||||
|
}
|
||||||
|
f.feedCalled[dir] = true
|
||||||
|
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
||||||
|
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
||||||
|
|
||||||
|
if len(f.dirBuf[dir]) > tcpFlowMaxBuffer {
|
||||||
|
return io.VerdictAccept
|
||||||
|
}
|
||||||
|
|
||||||
|
updated := false
|
||||||
|
for i := len(f.activeEntries) - 1; i >= 0; i-- {
|
||||||
|
entry := f.activeEntries[i]
|
||||||
|
update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir])
|
||||||
|
u1 := processPropUpdate(f.info.Props, entry.Name, update)
|
||||||
|
u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
|
||||||
|
updated = updated || u1 || u2
|
||||||
|
if done {
|
||||||
|
f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
|
||||||
|
f.doneEntries = append(f.doneEntries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated {
|
||||||
|
f.logger.TCPStreamPropUpdate(f.info, false)
|
||||||
|
rs, version := f.currentRuleset()
|
||||||
|
f.rulesetVersion = version
|
||||||
|
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||||
|
if rs != nil {
|
||||||
|
result = rs.Match(f.info)
|
||||||
|
}
|
||||||
|
action := result.Action
|
||||||
|
if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
|
||||||
|
verdict := actionToTCPVerdict(action)
|
||||||
|
f.lastVerdict = verdict
|
||||||
|
f.closeActiveEntries()
|
||||||
|
f.logger.TCPStreamAction(f.info, action, false)
|
||||||
|
return verdict
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.maybeFinalizeVerdict()
|
||||||
|
return f.lastVerdict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) maybeFinalizeVerdict() {
|
||||||
|
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept {
|
||||||
|
f.lastVerdict = io.VerdictAcceptStream
|
||||||
|
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) {
|
||||||
|
if tcp.SrcPort == f.srcPort {
|
||||||
|
return uint8(tcpDirC2S), false
|
||||||
|
}
|
||||||
|
return uint8(tcpDirS2C), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
|
||||||
|
if f.rulesetSource == nil {
|
||||||
|
return nil, f.rulesetVersion
|
||||||
|
}
|
||||||
|
return f.rulesetSource()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) rulesetChanged() bool {
|
||||||
|
_, version := f.currentRuleset()
|
||||||
|
return version != f.rulesetVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) closeActiveEntries() {
|
||||||
|
updated := false
|
||||||
|
for _, entry := range f.activeEntries {
|
||||||
|
update := entry.Stream.Close(false)
|
||||||
|
updated = updated || processPropUpdate(f.info.Props, entry.Name, update)
|
||||||
|
}
|
||||||
|
if updated {
|
||||||
|
f.logger.TCPStreamPropUpdate(f.info, true)
|
||||||
|
}
|
||||||
|
f.doneEntries = append(f.doneEntries, f.activeEntries...)
|
||||||
|
f.activeEntries = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpFlowManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
flows map[uint32]*tcpFlow
|
||||||
|
sfNode *snowflake.Node
|
||||||
|
logger Logger
|
||||||
|
rulesetSource func() (ruleset.Ruleset, uint64)
|
||||||
|
workerID int
|
||||||
|
macResolver *sourceMACResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager {
|
||||||
|
return &tcpFlowManager{
|
||||||
|
flows: make(map[uint32]*tcpFlow),
|
||||||
|
sfNode: node,
|
||||||
|
logger: logger,
|
||||||
|
workerID: workerID,
|
||||||
|
macResolver: macResolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict {
|
||||||
|
m.mu.Lock()
|
||||||
|
flow, ok := m.flows[streamID]
|
||||||
|
if !ok {
|
||||||
|
flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC)
|
||||||
|
m.flows[streamID] = flow
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
verdict := flow.feed(l3, tcp, payload)
|
||||||
|
|
||||||
|
if verdict == io.VerdictAcceptStream || verdict == io.VerdictDropStream || tcp.RST || tcp.FIN {
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.flows, streamID)
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return verdict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
|
||||||
|
id := m.sfNode.Generate()
|
||||||
|
ipSrc := net.IP(l3.SrcIP[:])
|
||||||
|
ipDst := net.IP(l3.DstIP[:])
|
||||||
|
if len(srcMAC) == 0 && m.macResolver != nil {
|
||||||
|
srcMAC = m.macResolver.Resolve(ipSrc)
|
||||||
|
}
|
||||||
|
info := ruleset.StreamInfo{
|
||||||
|
ID: id.Int64(),
|
||||||
|
Protocol: ruleset.ProtocolTCP,
|
||||||
|
SrcMAC: append(net.HardwareAddr(nil), srcMAC...),
|
||||||
|
DstMAC: append(net.HardwareAddr(nil), dstMAC...),
|
||||||
|
SrcIP: ipSrc,
|
||||||
|
DstIP: ipDst,
|
||||||
|
SrcPort: tcp.SrcPort,
|
||||||
|
DstPort: tcp.DstPort,
|
||||||
|
Props: make(analyzer.CombinedPropMap),
|
||||||
|
}
|
||||||
|
m.logger.TCPStreamNew(m.workerID, info)
|
||||||
|
rs, version := m.rulesetSource()
|
||||||
|
var ans []analyzer.TCPAnalyzer
|
||||||
|
if rs != nil {
|
||||||
|
ans = analyzersToTCPAnalyzers(rs.Analyzers(info))
|
||||||
|
}
|
||||||
|
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||||
|
for _, a := range ans {
|
||||||
|
entries = append(entries, &tcpFlowEntry{
|
||||||
|
Name: a.Name(),
|
||||||
|
Stream: a.NewTCP(analyzer.TCPInfo{
|
||||||
|
SrcIP: ipSrc,
|
||||||
|
DstIP: ipDst,
|
||||||
|
SrcPort: tcp.SrcPort,
|
||||||
|
DstPort: tcp.DstPort,
|
||||||
|
}, &analyzerLogger{
|
||||||
|
StreamID: id.Int64(),
|
||||||
|
Name: a.Name(),
|
||||||
|
Logger: m.logger,
|
||||||
|
}),
|
||||||
|
HasLimit: a.Limit() > 0,
|
||||||
|
Quota: a.Limit(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
flow := &tcpFlow{
|
||||||
|
streamID: streamID,
|
||||||
|
srcIP: l3.SrcIP,
|
||||||
|
dstIP: l3.DstIP,
|
||||||
|
srcPort: tcp.SrcPort,
|
||||||
|
dstPort: tcp.DstPort,
|
||||||
|
info: info,
|
||||||
|
virgin: true,
|
||||||
|
logger: m.logger,
|
||||||
|
rulesetSource: m.rulesetSource,
|
||||||
|
rulesetVersion: version,
|
||||||
|
activeEntries: entries,
|
||||||
|
lastVerdict: io.VerdictAccept,
|
||||||
|
}
|
||||||
|
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
||||||
|
return flow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
|
||||||
|
m.rulesetSource = func() (ruleset.Ruleset, uint64) {
|
||||||
|
return r, version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
||||||
|
if !entry.HasLimit {
|
||||||
|
update, done = entry.Stream.Feed(rev, true, false, 0, data)
|
||||||
|
} else {
|
||||||
|
qData := data
|
||||||
|
if len(qData) > entry.Quota {
|
||||||
|
qData = qData[:entry.Quota]
|
||||||
|
}
|
||||||
|
update, done = entry.Stream.Feed(rev, true, false, 0, qData)
|
||||||
|
entry.Quota -= len(qData)
|
||||||
|
if entry.Quota <= 0 {
|
||||||
|
closeUpdate = entry.Stream.Close(true)
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
|
||||||
|
tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans))
|
||||||
|
for _, a := range ans {
|
||||||
|
if ta, ok := a.(analyzer.TCPAnalyzer); ok {
|
||||||
|
tcpAns = append(tcpAns, ta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tcpAns
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionToTCPVerdict(a ruleset.Action) io.Verdict {
|
||||||
|
switch a {
|
||||||
|
case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify:
|
||||||
|
return io.VerdictAcceptStream
|
||||||
|
case ruleset.ActionBlock, ruleset.ActionDrop:
|
||||||
|
return io.VerdictDropStream
|
||||||
|
default:
|
||||||
|
return io.VerdictAcceptStream
|
||||||
|
}
|
||||||
|
}
|
||||||
+105
-89
@@ -10,20 +10,13 @@ import (
|
|||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/gopacket/reassembly"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var _ Engine = (*engine)(nil)
|
||||||
defaultChanSize = 64
|
|
||||||
defaultTCPMaxBufferedPagesTotal = 4096
|
|
||||||
defaultTCPMaxBufferedPagesPerConnection = 64
|
|
||||||
defaultUDPMaxStreams = 4096
|
|
||||||
)
|
|
||||||
|
|
||||||
type workerPacket struct {
|
type workerPacket struct {
|
||||||
StreamID uint32
|
StreamID uint32
|
||||||
Data []byte
|
Data []byte
|
||||||
LayerType gopacket.LayerType
|
|
||||||
SrcMAC net.HardwareAddr
|
SrcMAC net.HardwareAddr
|
||||||
DstMAC net.HardwareAddr
|
DstMAC net.HardwareAddr
|
||||||
SetVerdict func(io.Verdict, []byte) error
|
SetVerdict func(io.Verdict, []byte) error
|
||||||
@@ -35,12 +28,8 @@ type worker struct {
|
|||||||
logger Logger
|
logger Logger
|
||||||
macResolver *sourceMACResolver
|
macResolver *sourceMACResolver
|
||||||
|
|
||||||
tcpStreamFactory *tcpStreamFactory
|
tcpFlowMgr *tcpFlowManager
|
||||||
tcpStreamPool *reassembly.StreamPool
|
udpSM *udpStreamManager
|
||||||
tcpAssembler *reassembly.Assembler
|
|
||||||
|
|
||||||
udpStreamFactory *udpStreamFactory
|
|
||||||
udpStreamManager *udpStreamManager
|
|
||||||
|
|
||||||
modSerializeBuffer gopacket.SerializeBuffer
|
modSerializeBuffer gopacket.SerializeBuffer
|
||||||
}
|
}
|
||||||
@@ -51,23 +40,17 @@ type workerConfig struct {
|
|||||||
Logger Logger
|
Logger Logger
|
||||||
Ruleset ruleset.Ruleset
|
Ruleset ruleset.Ruleset
|
||||||
MACResolver *sourceMACResolver
|
MACResolver *sourceMACResolver
|
||||||
TCPMaxBufferedPagesTotal int
|
TCPMaxBufferedPagesTotal int // unused, kept for config compat
|
||||||
TCPMaxBufferedPagesPerConn int
|
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
|
||||||
UDPMaxStreams int
|
UDPMaxStreams int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *workerConfig) fillDefaults() {
|
func (c *workerConfig) fillDefaults() {
|
||||||
if c.ChanSize <= 0 {
|
if c.ChanSize <= 0 {
|
||||||
c.ChanSize = defaultChanSize
|
c.ChanSize = 64
|
||||||
}
|
|
||||||
if c.TCPMaxBufferedPagesTotal <= 0 {
|
|
||||||
c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal
|
|
||||||
}
|
|
||||||
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
|
||||||
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
|
||||||
}
|
}
|
||||||
if c.UDPMaxStreams <= 0 {
|
if c.UDPMaxStreams <= 0 {
|
||||||
c.UDPMaxStreams = defaultUDPMaxStreams
|
c.UDPMaxStreams = 4096
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,16 +60,12 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tcpSF := &tcpStreamFactory{
|
|
||||||
WorkerID: config.ID,
|
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode)
|
||||||
Logger: config.Logger,
|
if config.Ruleset != nil {
|
||||||
Node: sfNode,
|
tcpMgr.updateRuleset(config.Ruleset, 0)
|
||||||
Ruleset: config.Ruleset,
|
|
||||||
}
|
}
|
||||||
tcpStreamPool := reassembly.NewStreamPool(tcpSF)
|
|
||||||
tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
|
|
||||||
tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
|
|
||||||
tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
|
|
||||||
udpSF := &udpStreamFactory{
|
udpSF := &udpStreamFactory{
|
||||||
WorkerID: config.ID,
|
WorkerID: config.ID,
|
||||||
Logger: config.Logger,
|
Logger: config.Logger,
|
||||||
@@ -97,25 +76,24 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &worker{
|
return &worker{
|
||||||
id: config.ID,
|
id: config.ID,
|
||||||
packetChan: make(chan *workerPacket, config.ChanSize),
|
packetChan: make(chan *workerPacket, config.ChanSize),
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
macResolver: config.MACResolver,
|
macResolver: config.MACResolver,
|
||||||
tcpStreamFactory: tcpSF,
|
tcpFlowMgr: tcpMgr,
|
||||||
tcpStreamPool: tcpStreamPool,
|
udpSM: udpSM,
|
||||||
tcpAssembler: tcpAssembler,
|
|
||||||
udpStreamFactory: udpSF,
|
|
||||||
udpStreamManager: udpSM,
|
|
||||||
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *worker) Feed(p *workerPacket) {
|
func (w *worker) Feed(p *workerPacket) bool {
|
||||||
select {
|
select {
|
||||||
case w.packetChan <- p:
|
case w.packetChan <- p:
|
||||||
|
return true
|
||||||
default:
|
default:
|
||||||
_ = p.SetVerdict(io.VerdictAccept, nil)
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,78 +104,116 @@ func (w *worker) Run(ctx context.Context) {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case wPkt := <-w.packetChan:
|
case wp := <-w.packetChan:
|
||||||
if wPkt == nil {
|
if wp == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pkt := gopacket.NewPacket(wPkt.Data, wPkt.LayerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
v, b := w.handle(wp)
|
||||||
v, b := w.handle(wPkt.StreamID, pkt, wPkt.SrcMAC, wPkt.DstMAC)
|
_ = wp.SetVerdict(v, b)
|
||||||
_ = wPkt.SetVerdict(v, b)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
|
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
|
||||||
if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil {
|
w.tcpFlowMgr.updateRuleset(r, 0)
|
||||||
return err
|
return w.udpSM.factory.UpdateRuleset(r)
|
||||||
}
|
|
||||||
return w.udpStreamFactory.UpdateRuleset(r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *worker) handle(streamID uint32, p gopacket.Packet, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
|
||||||
netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
|
data := wp.Data
|
||||||
if netLayer == nil || trLayer == nil {
|
if len(data) == 0 {
|
||||||
// Invalid packet
|
|
||||||
return io.VerdictAccept, nil
|
return io.VerdictAccept, nil
|
||||||
}
|
}
|
||||||
ipFlow := netLayer.NetworkFlow()
|
|
||||||
if len(srcMAC) == 0 && w.macResolver != nil {
|
ipVersion := data[0] >> 4
|
||||||
srcMAC = w.macResolver.Resolve(net.IP(ipFlow.Src().Raw()))
|
if ipVersion == 4 {
|
||||||
|
l3, transport, ok := ParseL3(data)
|
||||||
|
if !ok {
|
||||||
|
return io.VerdictAccept, nil
|
||||||
}
|
}
|
||||||
switch tr := trLayer.(type) {
|
switch l3.Protocol {
|
||||||
case *layers.TCP:
|
case 6: // TCP
|
||||||
return w.handleTCP(ipFlow, srcMAC, dstMAC, p.Metadata(), tr), nil
|
tcp, payload, ok := ParseTCP(transport)
|
||||||
case *layers.UDP:
|
if !ok {
|
||||||
v, modPayload := w.handleUDP(streamID, ipFlow, srcMAC, dstMAC, tr)
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
verdict := w.tcpFlowMgr.handle(
|
||||||
|
wp.StreamID, l3, tcp, payload,
|
||||||
|
wp.SrcMAC, wp.DstMAC,
|
||||||
|
)
|
||||||
|
return verdict, nil
|
||||||
|
|
||||||
|
case 17: // UDP
|
||||||
|
udp, payload, ok := ParseUDP(transport)
|
||||||
|
if !ok {
|
||||||
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
v, modPayload := w.handleUDP(
|
||||||
|
wp.StreamID, l3, udp, payload,
|
||||||
|
wp.SrcMAC, wp.DstMAC,
|
||||||
|
)
|
||||||
if v == io.VerdictAcceptModify && modPayload != nil {
|
if v == io.VerdictAcceptModify && modPayload != nil {
|
||||||
tr.Payload = modPayload
|
return w.serializeModifiedUDP(data, l3, udp, transport, modPayload)
|
||||||
_ = tr.SetNetworkLayerForChecksum(netLayer)
|
|
||||||
_ = w.modSerializeBuffer.Clear()
|
|
||||||
err := gopacket.SerializePacket(w.modSerializeBuffer,
|
|
||||||
gopacket.SerializeOptions{
|
|
||||||
FixLengths: true,
|
|
||||||
ComputeChecksums: true,
|
|
||||||
}, p)
|
|
||||||
if err != nil {
|
|
||||||
// Just accept without modification for now
|
|
||||||
return io.VerdictAccept, nil
|
|
||||||
}
|
|
||||||
return v, w.modSerializeBuffer.Bytes()
|
|
||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Unsupported protocol
|
|
||||||
return io.VerdictAccept, nil
|
return io.VerdictAccept, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (w *worker) handleTCP(ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
|
|
||||||
ctx := &tcpContext{
|
|
||||||
PacketMetadata: pMeta,
|
|
||||||
Verdict: tcpVerdictAccept,
|
|
||||||
SrcMAC: srcMAC,
|
|
||||||
DstMAC: dstMAC,
|
|
||||||
}
|
}
|
||||||
w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx)
|
|
||||||
return io.Verdict(ctx.Verdict)
|
// Ethernet frame path (for custom PacketIO)
|
||||||
|
if ipVersion == 6 {
|
||||||
|
// TODO: IPv6 support with raw parsing
|
||||||
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.VerdictAccept, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, udp *layers.UDP) (io.Verdict, []byte) {
|
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
||||||
ctx := &udpContext{
|
ipSrc := net.IP(l3.SrcIP[:])
|
||||||
|
ipDst := net.IP(l3.DstIP[:])
|
||||||
|
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
|
||||||
|
udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)})
|
||||||
|
|
||||||
|
if len(srcMAC) == 0 && w.macResolver != nil {
|
||||||
|
srcMAC = w.macResolver.Resolve(ipSrc)
|
||||||
|
}
|
||||||
|
|
||||||
|
uc := &udpContext{
|
||||||
Verdict: udpVerdictAccept,
|
Verdict: udpVerdictAccept,
|
||||||
SrcMAC: srcMAC,
|
SrcMAC: srcMAC,
|
||||||
DstMAC: dstMAC,
|
DstMAC: dstMAC,
|
||||||
}
|
}
|
||||||
w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx)
|
// Temporarily set payload on a UDP layer so existing UDP handling works
|
||||||
return io.Verdict(ctx.Verdict), ctx.Packet
|
// We pass the payload through the context
|
||||||
|
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
||||||
|
BaseLayer: layers.BaseLayer{Payload: payload},
|
||||||
|
SrcPort: layers.UDPPort(udp.SrcPort),
|
||||||
|
DstPort: layers.UDPPort(udp.DstPort),
|
||||||
|
}, uc)
|
||||||
|
return io.Verdict(uc.Verdict), uc.Packet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) {
|
||||||
|
ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
||||||
|
netLayer := ipPkt.NetworkLayer()
|
||||||
|
trLayer := ipPkt.TransportLayer()
|
||||||
|
if netLayer == nil || trLayer == nil {
|
||||||
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
udpLayer, ok := trLayer.(*layers.UDP)
|
||||||
|
if !ok {
|
||||||
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
udpLayer.Payload = modPayload
|
||||||
|
_ = udpLayer.SetNetworkLayerForChecksum(netLayer)
|
||||||
|
_ = w.modSerializeBuffer.Clear()
|
||||||
|
err := gopacket.SerializePacket(w.modSerializeBuffer,
|
||||||
|
gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
|
||||||
|
if err != nil {
|
||||||
|
return io.VerdictAccept, nil
|
||||||
|
}
|
||||||
|
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
|
||||||
}
|
}
|
||||||
|
|||||||
+106
-76
@@ -18,9 +18,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nfqueueNum = 100
|
nfqueueNumStart = 100
|
||||||
nfqueueMaxPacketLen = 0xFFFF
|
|
||||||
nfqueueDefaultQueueSize = 128
|
nfqueueDefaultQueueSize = 128
|
||||||
|
nfqueueDefaultMaxLen = 0xFFFF
|
||||||
|
|
||||||
nfqueueConnMarkAccept = 1001
|
nfqueueConnMarkAccept = 1001
|
||||||
nfqueueConnMarkDrop = 1002
|
nfqueueConnMarkDrop = 1002
|
||||||
@@ -29,17 +29,25 @@ const (
|
|||||||
nftTable = "mellaris"
|
nftTable = "mellaris"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
|
||||||
if local && rst {
|
if local && rst {
|
||||||
return nil, errors.New("tcp rst is not supported in local mode")
|
return nil, errors.New("tcp rst is not supported in local mode")
|
||||||
}
|
}
|
||||||
|
if numQueues < 1 {
|
||||||
|
numQueues = 1
|
||||||
|
}
|
||||||
table := &nftTableSpec{
|
table := &nftTableSpec{
|
||||||
Family: nftFamily,
|
Family: nftFamily,
|
||||||
Table: nftTable,
|
Table: nftTable,
|
||||||
}
|
}
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
|
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
|
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum))
|
queueEnd := nfqueueNumStart + numQueues - 1
|
||||||
|
if numQueues == 1 {
|
||||||
|
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart))
|
||||||
|
} else {
|
||||||
|
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, queueEnd))
|
||||||
|
}
|
||||||
if local {
|
if local {
|
||||||
table.Chains = []nftChainSpec{
|
table.Chains = []nftChainSpec{
|
||||||
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
|
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
|
||||||
@@ -52,7 +60,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
|||||||
}
|
}
|
||||||
for i := range table.Chains {
|
for i := range table.Chains {
|
||||||
c := &table.Chains[i]
|
c := &table.Chains[i]
|
||||||
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
|
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK")
|
||||||
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
||||||
if rst {
|
if rst {
|
||||||
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
||||||
@@ -63,10 +71,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
|||||||
return table, nil
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateIptRules(local, rst bool) ([]iptRule, error) {
|
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
|
||||||
if local && rst {
|
if local && rst {
|
||||||
return nil, errors.New("tcp rst is not supported in local mode")
|
return nil, errors.New("tcp rst is not supported in local mode")
|
||||||
}
|
}
|
||||||
|
if numQueues < 1 {
|
||||||
|
numQueues = 1
|
||||||
|
}
|
||||||
var chains []string
|
var chains []string
|
||||||
if local {
|
if local {
|
||||||
chains = []string{"INPUT", "OUTPUT"}
|
chains = []string{"INPUT", "OUTPUT"}
|
||||||
@@ -75,16 +86,19 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
|
|||||||
}
|
}
|
||||||
rules := make([]iptRule, 0, 4*len(chains))
|
rules := make([]iptRule, 0, 4*len(chains))
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
// Bypass protected connections
|
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
||||||
if rst {
|
if rst {
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
||||||
}
|
}
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}})
|
if numQueues == 1 {
|
||||||
|
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}})
|
||||||
|
} else {
|
||||||
|
queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
|
||||||
|
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,12 +107,12 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
|
|||||||
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
||||||
|
|
||||||
type nfqueuePacketIO struct {
|
type nfqueuePacketIO struct {
|
||||||
n *nfqueue.Nfqueue
|
nqs []*nfqueue.Nfqueue
|
||||||
|
numQueues int
|
||||||
local bool
|
local bool
|
||||||
rst bool
|
rst bool
|
||||||
rSet bool // whether the nftables/iptables rules have been set
|
rSet bool
|
||||||
|
|
||||||
// iptables not nil = use iptables instead of nftables
|
|
||||||
ipt4 *iptables.IPTables
|
ipt4 *iptables.IPTables
|
||||||
ipt6 *iptables.IPTables
|
ipt6 *iptables.IPTables
|
||||||
|
|
||||||
@@ -111,16 +125,23 @@ type NFQueuePacketIOConfig struct {
|
|||||||
WriteBuffer int
|
WriteBuffer int
|
||||||
Local bool
|
Local bool
|
||||||
RST bool
|
RST bool
|
||||||
|
NumQueues int
|
||||||
|
MaxPacketLen uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
||||||
if config.QueueSize == 0 {
|
if config.QueueSize == 0 {
|
||||||
config.QueueSize = nfqueueDefaultQueueSize
|
config.QueueSize = nfqueueDefaultQueueSize
|
||||||
}
|
}
|
||||||
|
if config.NumQueues <= 0 {
|
||||||
|
config.NumQueues = 1
|
||||||
|
}
|
||||||
|
if config.MaxPacketLen == 0 {
|
||||||
|
config.MaxPacketLen = nfqueueDefaultMaxLen
|
||||||
|
}
|
||||||
var ipt4, ipt6 *iptables.IPTables
|
var ipt4, ipt6 *iptables.IPTables
|
||||||
var err error
|
var err error
|
||||||
if nftCheck() != nil {
|
if nftCheck() != nil {
|
||||||
// We prefer nftables, but if it's not available, fall back to iptables
|
|
||||||
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -130,32 +151,46 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nqs := make([]*nfqueue.Nfqueue, config.NumQueues)
|
||||||
|
for i := range nqs {
|
||||||
n, err := nfqueue.Open(&nfqueue.Config{
|
n, err := nfqueue.Open(&nfqueue.Config{
|
||||||
NfQueue: nfqueueNum,
|
NfQueue: uint16(nfqueueNumStart + i),
|
||||||
MaxPacketLen: nfqueueMaxPacketLen,
|
MaxPacketLen: config.MaxPacketLen,
|
||||||
MaxQueueLen: config.QueueSize,
|
MaxQueueLen: config.QueueSize,
|
||||||
Copymode: nfqueue.NfQnlCopyPacket,
|
Copymode: nfqueue.NfQnlCopyPacket,
|
||||||
Flags: nfqueue.NfQaCfgFlagConntrack,
|
Flags: nfqueue.NfQaCfgFlagConntrack,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
for j := 0; j < i; j++ {
|
||||||
|
nqs[j].Close()
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if config.ReadBuffer > 0 {
|
if config.ReadBuffer > 0 {
|
||||||
err = n.Con.SetReadBuffer(config.ReadBuffer)
|
err = n.Con.SetReadBuffer(config.ReadBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = n.Close()
|
for j := 0; j <= i; j++ {
|
||||||
|
nqs[j].Close()
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if config.WriteBuffer > 0 {
|
if config.WriteBuffer > 0 {
|
||||||
err = n.Con.SetWriteBuffer(config.WriteBuffer)
|
err = n.Con.SetWriteBuffer(config.WriteBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = n.Close()
|
for j := 0; j <= i; j++ {
|
||||||
|
nqs[j].Close()
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
nqs[i] = n
|
||||||
|
}
|
||||||
|
|
||||||
return &nfqueuePacketIO{
|
return &nfqueuePacketIO{
|
||||||
n: n,
|
nqs: nqs,
|
||||||
|
numQueues: config.NumQueues,
|
||||||
local: config.Local,
|
local: config.Local,
|
||||||
rst: config.RST,
|
rst: config.RST,
|
||||||
ipt4: ipt4,
|
ipt4: ipt4,
|
||||||
@@ -175,12 +210,14 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
||||||
err := n.n.RegisterWithErrorFunc(ctx,
|
for i, nq := range nio.nqs {
|
||||||
|
nq := nq
|
||||||
|
err := nq.RegisterWithErrorFunc(ctx,
|
||||||
func(a nfqueue.Attribute) int {
|
func(a nfqueue.Attribute) int {
|
||||||
if ok, verdict := n.packetAttributeSanityCheck(a); !ok {
|
if ok, verdict := nio.packetAttributeSanityCheck(a); !ok {
|
||||||
if a.PacketID != nil {
|
if a.PacketID != nil {
|
||||||
_ = n.n.SetVerdict(*a.PacketID, verdict)
|
_ = nq.SetVerdict(*a.PacketID, verdict)
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -188,13 +225,13 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
|
|||||||
id: *a.PacketID,
|
id: *a.PacketID,
|
||||||
streamID: ctIDFromCtBytes(*a.Ct),
|
streamID: ctIDFromCtBytes(*a.Ct),
|
||||||
data: *a.Payload,
|
data: *a.Payload,
|
||||||
|
nq: nq,
|
||||||
}
|
}
|
||||||
return okBoolToInt(cb(p, nil))
|
return okBoolToInt(cb(p, nil))
|
||||||
},
|
},
|
||||||
func(e error) int {
|
func(e error) int {
|
||||||
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
|
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
|
||||||
if errors.Is(opErr.Err, unix.ENOBUFS) {
|
if errors.Is(opErr.Err, unix.ENOBUFS) {
|
||||||
// Kernel buffer temporarily full, ignore
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -203,32 +240,33 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !n.rSet {
|
|
||||||
if n.ipt4 != nil {
|
|
||||||
err = n.setupIpt(n.local, n.rst, false)
|
|
||||||
} else {
|
|
||||||
err = n.setupNft(n.local, n.rst, false)
|
|
||||||
}
|
}
|
||||||
|
if !nio.rSet {
|
||||||
|
if nio.ipt4 != nil {
|
||||||
|
err := nio.setupIpt(nio.local, nio.rst, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
n.rSet = true
|
} else {
|
||||||
|
err := nio.setupNft(nio.local, nio.rst, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nio.rSet = true
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
|
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
|
||||||
if a.PacketID == nil {
|
if a.PacketID == nil {
|
||||||
// Re-inject to NFQUEUE is actually not possible in this condition
|
|
||||||
return false, -1
|
return false, -1
|
||||||
}
|
}
|
||||||
if a.Payload == nil || len(*a.Payload) < 20 {
|
if a.Payload == nil || len(*a.Payload) < 20 {
|
||||||
// 20 is the minimum possible size of an IP packet
|
|
||||||
return false, nfqueue.NfDrop
|
return false, nfqueue.NfDrop
|
||||||
}
|
}
|
||||||
if a.Ct == nil {
|
if a.Ct == nil {
|
||||||
// Multicast packets may not have a conntrack, but only appear in local mode
|
if nio.local {
|
||||||
if n.local {
|
|
||||||
return false, nfqueue.NfAccept
|
return false, nfqueue.NfAccept
|
||||||
}
|
}
|
||||||
return false, nfqueue.NfDrop
|
return false, nfqueue.NfDrop
|
||||||
@@ -236,46 +274,54 @@ func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bo
|
|||||||
return true, -1
|
return true, -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
|
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
|
||||||
nP, ok := p.(*nfqueuePacket)
|
nP, ok := p.(*nfqueuePacket)
|
||||||
if !ok {
|
if !ok {
|
||||||
return &ErrInvalidPacket{Err: errNotNFQueuePacket}
|
return &ErrInvalidPacket{Err: errNotNFQueuePacket}
|
||||||
}
|
}
|
||||||
switch v {
|
switch v {
|
||||||
case VerdictAccept:
|
case VerdictAccept:
|
||||||
return n.n.SetVerdict(nP.id, nfqueue.NfAccept)
|
return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept)
|
||||||
case VerdictAcceptModify:
|
case VerdictAcceptModify:
|
||||||
return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
|
return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
|
||||||
case VerdictAcceptStream:
|
case VerdictAcceptStream:
|
||||||
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
|
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
|
||||||
case VerdictDrop:
|
case VerdictDrop:
|
||||||
return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
|
return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop)
|
||||||
case VerdictDropStream:
|
case VerdictDropStream:
|
||||||
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
|
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
|
||||||
default:
|
default:
|
||||||
// Invalid verdict, ignore for now
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return n.protectedDialer.DialContext(ctx, network, address)
|
return nio.protectedDialer.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) Close() error {
|
func (nio *nfqueuePacketIO) Close() error {
|
||||||
if n.rSet {
|
if nio.rSet {
|
||||||
if n.ipt4 != nil {
|
if nio.ipt4 != nil {
|
||||||
_ = n.setupIpt(n.local, n.rst, true)
|
_ = nio.setupIpt(nio.local, nio.rst, true)
|
||||||
} else {
|
} else {
|
||||||
_ = n.setupNft(n.local, n.rst, true)
|
_ = nio.setupNft(nio.local, nio.rst, true)
|
||||||
}
|
}
|
||||||
n.rSet = false
|
nio.rSet = false
|
||||||
}
|
}
|
||||||
return n.n.Close()
|
var errs []error
|
||||||
|
for _, nq := range nio.nqs {
|
||||||
|
if err := nq.Close(); err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return errs[0]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
||||||
rules, err := generateNftRules(local, rst)
|
rules, err := generateNftRules(local, rst, nio.numQueues)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -283,30 +329,23 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
|||||||
if remove {
|
if remove {
|
||||||
err = nftDelete(nftFamily, nftTable)
|
err = nftDelete(nftFamily, nftTable)
|
||||||
} else {
|
} else {
|
||||||
// Delete first to make sure no leftover rules
|
|
||||||
_ = nftDelete(nftFamily, nftTable)
|
_ = nftDelete(nftFamily, nftTable)
|
||||||
err = nftAdd(rulesText)
|
err = nftAdd(rulesText)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
|
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
|
||||||
rules, err := generateIptRules(local, rst)
|
rules, err := generateIptRules(local, rst, nio.numQueues)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if remove {
|
if remove {
|
||||||
err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
|
err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
|
||||||
} else {
|
} else {
|
||||||
err = iptsBatchAppendUnique([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
|
err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Packet = (*nfqueuePacket)(nil)
|
var _ Packet = (*nfqueuePacket)(nil)
|
||||||
@@ -315,30 +354,22 @@ type nfqueuePacket struct {
|
|||||||
id uint32
|
id uint32
|
||||||
streamID uint32
|
streamID uint32
|
||||||
data []byte
|
data []byte
|
||||||
|
nq *nfqueue.Nfqueue
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *nfqueuePacket) StreamID() uint32 {
|
func (p *nfqueuePacket) StreamID() uint32 { return p.streamID }
|
||||||
return p.streamID
|
func (p *nfqueuePacket) Data() []byte { return p.data }
|
||||||
}
|
|
||||||
|
|
||||||
func (p *nfqueuePacket) Data() []byte {
|
|
||||||
return p.data
|
|
||||||
}
|
|
||||||
|
|
||||||
func okBoolToInt(ok bool) int {
|
func okBoolToInt(ok bool) int {
|
||||||
if ok {
|
if ok {
|
||||||
return 0
|
return 0
|
||||||
} else {
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func nftCheck() error {
|
func nftCheck() error {
|
||||||
_, err := exec.LookPath("nft")
|
_, err := exec.LookPath("nft")
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func nftAdd(input string) error {
|
func nftAdd(input string) error {
|
||||||
@@ -363,7 +394,6 @@ func (t *nftTableSpec) String() string {
|
|||||||
for _, c := range t.Chains {
|
for _, c := range t.Chains {
|
||||||
chains = append(chains, c.String())
|
chains = append(chains, c.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(`
|
return fmt.Sprintf(`
|
||||||
%s
|
%s
|
||||||
|
|
||||||
|
|||||||
+27
-14
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/expr-lang/expr/builtin"
|
"github.com/expr-lang/expr/builtin"
|
||||||
@@ -67,6 +68,19 @@ type compiledExprRule struct {
|
|||||||
|
|
||||||
var _ Ruleset = (*exprRuleset)(nil)
|
var _ Ruleset = (*exprRuleset)(nil)
|
||||||
|
|
||||||
|
var (
|
||||||
|
envPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make(map[string]any, 16)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
subMapPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make(map[string]any, 8)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
type exprRuleset struct {
|
type exprRuleset struct {
|
||||||
Rules []compiledExprRule
|
Rules []compiledExprRule
|
||||||
Ans []analyzer.Analyzer
|
Ans []analyzer.Analyzer
|
||||||
@@ -79,7 +93,9 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||||
env := streamInfoToExprEnv(info)
|
env := envPool.Get().(map[string]any)
|
||||||
|
clear(env)
|
||||||
|
populateExprEnv(env, info)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for _, rule := range r.Rules {
|
for _, rule := range r.Rules {
|
||||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||||
@@ -99,6 +115,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
r.Logger.Log(logInfo, rule.Name)
|
r.Logger.Log(logInfo, rule.Name)
|
||||||
}
|
}
|
||||||
if rule.Action != nil {
|
if rule.Action != nil {
|
||||||
|
envPool.Put(env)
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: *rule.Action,
|
Action: *rule.Action,
|
||||||
ModInstance: rule.ModInstance,
|
ModInstance: rule.ModInstance,
|
||||||
@@ -106,7 +123,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// No match
|
envPool.Put(env)
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: ActionMaybe,
|
Action: ActionMaybe,
|
||||||
}
|
}
|
||||||
@@ -228,30 +245,26 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
|
func populateExprEnv(m map[string]any, info StreamInfo) {
|
||||||
m := map[string]interface{}{
|
m["id"] = info.ID
|
||||||
"id": info.ID,
|
m["proto"] = info.Protocol.String()
|
||||||
"proto": info.Protocol.String(),
|
m["mac"] = map[string]string{
|
||||||
"mac": map[string]string{
|
|
||||||
"src": info.SrcMAC.String(),
|
"src": info.SrcMAC.String(),
|
||||||
"dst": info.DstMAC.String(),
|
"dst": info.DstMAC.String(),
|
||||||
},
|
}
|
||||||
"ip": map[string]string{
|
m["ip"] = map[string]string{
|
||||||
"src": info.SrcIP.String(),
|
"src": info.SrcIP.String(),
|
||||||
"dst": info.DstIP.String(),
|
"dst": info.DstIP.String(),
|
||||||
},
|
}
|
||||||
"port": map[string]uint16{
|
m["port"] = map[string]uint16{
|
||||||
"src": info.SrcPort,
|
"src": info.SrcPort,
|
||||||
"dst": info.DstPort,
|
"dst": info.DstPort,
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for anName, anProps := range info.Props {
|
for anName, anProps := range info.Props {
|
||||||
if len(anProps) != 0 {
|
if len(anProps) != 0 {
|
||||||
// Ignore analyzers with empty properties
|
|
||||||
m[anName] = anProps
|
m[anName] = anProps
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isBuiltInAnalyzer(name string) bool {
|
func isBuiltInAnalyzer(name string) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user