diff --git a/app.go b/app.go
index ea4e552..b1cff75 100644
--- a/app.go
+++ b/app.go
@@ -44,11 +44,13 @@ func New(cfg Config, opts Options) (*App, error) {
 	ownsIO := false
 	if packetIO == nil {
 		packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
-			QueueSize:   cfg.IO.QueueSize,
-			ReadBuffer:  cfg.IO.ReadBuffer,
-			WriteBuffer: cfg.IO.WriteBuffer,
-			Local:       cfg.IO.Local,
-			RST:         cfg.IO.RST,
+			QueueSize:    cfg.IO.QueueSize,
+			ReadBuffer:   cfg.IO.ReadBuffer,
+			WriteBuffer:  cfg.IO.WriteBuffer,
+			Local:        cfg.IO.Local,
+			RST:          cfg.IO.RST,
+			NumQueues:    cfg.IO.NumQueues,
+			MaxPacketLen: cfg.IO.MaxPacketLen,
 		})
 		if err != nil {
 			return nil, ConfigError{Field: "io", Err: err}
diff --git a/config.go b/config.go
index d2d5afe..eaab254 100644
--- a/config.go
+++ b/config.go
@@ -17,11 +17,17 @@ type Config struct {
 
 // IOConfig configures packet IO.
 type IOConfig struct {
-	QueueSize   uint32 `mapstructure:"queueSize" yaml:"queueSize"`
-	ReadBuffer  int    `mapstructure:"rcvBuf" yaml:"rcvBuf"`
-	WriteBuffer int    `mapstructure:"sndBuf" yaml:"sndBuf"`
-	Local       bool   `mapstructure:"local" yaml:"local"`
-	RST         bool   `mapstructure:"rst" yaml:"rst"`
+	QueueSize    uint32 `mapstructure:"queueSize" yaml:"queueSize"`
+	ReadBuffer   int    `mapstructure:"rcvBuf" yaml:"rcvBuf"`
+	WriteBuffer  int    `mapstructure:"sndBuf" yaml:"sndBuf"`
+	Local        bool   `mapstructure:"local" yaml:"local"`
+	RST          bool   `mapstructure:"rst" yaml:"rst"`
+	// NumQueues is the number of NFQUEUE queues traffic is spread across.
+	// Values below 1 are treated as 1.
+	NumQueues    int    `mapstructure:"numQueues" yaml:"numQueues"`
+	// MaxPacketLen is passed to the NFQUEUE packet IO as its maximum packet
+	// length. Zero selects the default of 0xFFFF.
+	MaxPacketLen uint32 `mapstructure:"maxPacketLen" yaml:"maxPacketLen"`
 
 	// PacketIO overrides NFQueue creation when set.
 	// When provided, App.Close will call PacketIO.Close.
diff --git a/engine/engine.go b/engine/engine.go
index ccd8dfe..269a511 100644
--- a/engine/engine.go
+++ b/engine/engine.go
@@ -2,16 +2,11 @@ package engine
 
 import (
 	"context"
-	"encoding/binary"
-	"runtime"
 	"sync"
 	"sync/atomic"
 
 	"git.difuse.io/Difuse/Mellaris/io"
 	"git.difuse.io/Difuse/Mellaris/ruleset"
-
-	"github.com/google/gopacket"
-	"github.com/google/gopacket/layers"
 )
 
 var _ Engine = (*engine)(nil)
@@ -27,12 +22,17 @@ type engine struct {
 	workers     []*worker
 	verdicts    sync.Map     // streamID(uint32) → verdictEntry
 	verdictsGen atomic.Int64 // incremented on ruleset update
+
+	// overflowCh parks packets whose worker queue was full; drainOverflow
+	// accepts them. overflowOnce starts that goroutine at most once.
+	overflowCh   chan *workerPacket
+	overflowOnce sync.Once
 }
 
 func NewEngine(config Config) (Engine, error) {
 	workerCount := config.Workers
 	if workerCount <= 0 {
-		workerCount = runtime.NumCPU()
+		workerCount = 1
 	}
 	macResolver := newSourceMACResolver()
 	var err error
@@ -53,9 +53,10 @@ func NewEngine(config Config) (Engine, error) {
 		}
 	}
 	e := &engine{
-		logger:  config.Logger,
-		io:      config.IO,
-		workers: workers,
+		logger:     config.Logger,
+		io:         config.IO,
+		workers:    workers,
+		overflowCh: make(chan *workerPacket, 1024),
 	}
 	return e, nil
 }
@@ -75,6 +76,10 @@ func (e *engine) Run(ctx context.Context) error {
 	ioCtx, ioCancel := context.WithCancel(ctx)
 	defer ioCancel()
 
+	e.overflowOnce.Do(func() {
+		go e.drainOverflow(ioCtx)
+	})
+
 	for _, w := range e.workers {
 		go w.Run(ioCtx)
 	}
@@ -111,55 +116,64 @@ func (e *engine) dispatch(p io.Packet) bool {
 	}
 
 	data := p.Data()
-	layerType, srcMAC, dstMAC, ok := classifyPacket(data)
-	if !ok {
+	if !validPacket(data) {
 		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
 		return true
 	}
 	gen := e.verdictsGen.Load()
 	index := streamID % uint32(len(e.workers))
-	e.workers[index].Feed(&workerPacket{
+	wp := &workerPacket{
 		StreamID: streamID,
 		Data:     data,
-		LayerType: layerType,
-		SrcMAC:    srcMAC,
-		DstMAC:    dstMAC,
 		SetVerdict: func(v io.Verdict, b []byte) error {
 			if v == io.VerdictAcceptStream || v == io.VerdictDropStream {
 				e.verdicts.Store(streamID, verdictEntry{Verdict: v, Gen: gen})
 			}
 			return e.io.SetVerdict(p, v, b)
 		},
-	})
+	}
+	if !e.workers[index].Feed(wp) {
+		// The worker queue is full: park the packet on overflowCh, where
+		// drainOverflow accepts it.
+		select {
+		case e.overflowCh <- wp:
+		default:
+			// The overflow channel is full too. Fail open here; otherwise
+			// the packet would never be given a verdict.
+			_ = wp.SetVerdict(io.VerdictAccept, nil)
+		}
+	}
 	return true
 }
 
-// classifyPacket detects packet framing and returns a gopacket decode layer
-// plus best-effort source/destination MAC addresses when available.
-func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) {
+// validPacket reports whether data looks like a raw IPv4/IPv6 packet (by its
+// version nibble) or an Ethernet frame whose EtherType is IPv4 or IPv6.
+func validPacket(data []byte) bool {
 	if len(data) == 0 {
-		return 0, nil, nil, false
+		return false
 	}
-
-	// Fast path for IP packets (NFQUEUE payloads are typically IP-only).
 	ipVersion := data[0] >> 4
-	if ipVersion == 4 {
-		return layers.LayerTypeIPv4, nil, nil, true
+	if ipVersion == 4 || ipVersion == 6 {
+		return true
 	}
-	if ipVersion == 6 {
-		return layers.LayerTypeIPv6, nil, nil, true
-	}
-
-	// Ethernet frame path (for custom PacketIO implementations).
 	if len(data) >= 14 {
-		etherType := binary.BigEndian.Uint16(data[12:14])
-		if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) {
-			return layers.LayerTypeEthernet,
-				append([]byte(nil), data[6:12]...),
-				append([]byte(nil), data[:6]...),
-				true
+		etherType := uint16(data[12])<<8 | uint16(data[13])
+		if etherType == 0x0800 || etherType == 0x86DD {
+			return true
+		}
+	}
+	return false
+}
+
+// drainOverflow accepts every packet parked on overflowCh until ctx is
+// cancelled.
+func (e *engine) drainOverflow(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case wp := <-e.overflowCh:
+			_ = wp.SetVerdict(io.VerdictAccept, nil)
 		}
 	}
-
-	return 0, nil, nil, false
 }
diff --git a/engine/packet.go b/engine/packet.go
new file mode 100644
index 0000000..c784572
--- /dev/null
+++ b/engine/packet.go
@@ -0,0 +1,113 @@
+package engine
+
+import "net"
+
+// L3Info holds the IPv4 header fields extracted by ParseL3. IHL is the header
+// length in 32-bit words; Length is the total length clamped to the bytes
+// actually present.
+type L3Info struct {
+	Version  uint8
+	Protocol uint8
+	IHL      uint8
+	SrcIP    [4]byte
+	DstIP    [4]byte
+	Length   uint16
+}
+
+// SrcIPAddr returns the source address as a net.IP.
+func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) }
+
+// DstIPAddr returns the destination address as a net.IP.
+func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) }
+
+// TCPInfo holds the TCP header fields extracted by ParseTCP. HdrLen is the
+// header length in bytes, options included.
+type TCPInfo struct {
+	SrcPort uint16
+	DstPort uint16
+	Seq     uint32
+	Ack     uint32
+	HdrLen  uint8
+	SYN     bool
+	FIN     bool
+	RST     bool
+	ACK     bool
+}
+
+// UDPInfo holds the UDP ports extracted by ParseUDP.
+type UDPInfo struct {
+	SrcPort uint16
+	DstPort uint16
+}
+
+// ParseL3 decodes an IPv4 header from data. On success it returns the header
+// fields, the bytes after the header (bounded by the total-length field when
+// that field is plausible), and ok=true. Anything else, IPv6 included, yields
+// ok=false.
+func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
+	if len(data) < 20 {
+		return
+	}
+	version := data[0] >> 4
+	if version != 4 {
+		return
+	}
+	ihl := data[0] & 0x0F
+	if ihl < 5 || len(data) < int(ihl)*4 {
+		return
+	}
+	totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
+	if totalLen < int(ihl)*4 || totalLen > len(data) {
+		// Implausible total length: fall back to the bytes actually present.
+		totalLen = len(data)
+	}
+	return L3Info{
+		Version:  4,
+		Protocol: data[9],
+		IHL:      ihl,
+		Length:   uint16(totalLen),
+		SrcIP:    [4]byte{data[12], data[13], data[14], data[15]},
+		DstIP:    [4]byte{data[16], data[17], data[18], data[19]},
+	}, data[ihl*4:totalLen], true
+}
+
+// ParseTCP decodes a TCP header from transport, the bytes following the IP
+// header. On success it returns the header fields, the payload after the
+// header and its options, and true. It reports false when transport is shorter
+// than 20 bytes or the data offset is below 20 bytes or past the end.
+func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
+	if len(transport) < 20 {
+		return TCPInfo{}, nil, false
+	}
+	dataOff := uint8(transport[12]>>4) * 4
+	if dataOff < 20 || len(transport) < int(dataOff) {
+		return TCPInfo{}, nil, false
+	}
+	flags := transport[13]
+	// The payload is everything after the header. Slice with the offset alone:
+	// a uint8 end index wraps once the segment is longer than 255 bytes.
+	return TCPInfo{
+		SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
+		DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
+		Seq:     uint32(transport[4])<<24 | uint32(transport[5])<<16 | uint32(transport[6])<<8 | uint32(transport[7]),
+		Ack:     uint32(transport[8])<<24 | uint32(transport[9])<<16 | uint32(transport[10])<<8 | uint32(transport[11]),
+		HdrLen:  dataOff,
+		SYN:     flags&0x02 != 0,
+		FIN:     flags&0x01 != 0,
+		RST:     flags&0x04 != 0,
+		ACK:     flags&0x10 != 0,
+	}, transport[dataOff:], true
+}
+
+// ParseUDP decodes the UDP ports from transport and returns them with the
+// bytes after the 8-byte header. It reports false when transport is shorter
+// than the header.
+func ParseUDP(transport []byte) (UDPInfo, []byte, bool) {
+	if len(transport) < 8 {
+		return UDPInfo{}, nil, false
+	}
+	return UDPInfo{
+		SrcPort: uint16(transport[0])<<8 | uint16(transport[1]),
+		DstPort: uint16(transport[2])<<8 | uint16(transport[3]),
+	}, transport[8:], true
+}
diff --git a/engine/tcp.go b/engine/tcp.go
deleted file mode 100644
index 4bb8c8c..0000000
--- a/engine/tcp.go +++ /dev/null @@ -1,262 +0,0 @@ -package engine - -import ( - "net" - "sync" - - "git.difuse.io/Difuse/Mellaris/analyzer" - "git.difuse.io/Difuse/Mellaris/io" - "git.difuse.io/Difuse/Mellaris/ruleset" - - "github.com/bwmarrin/snowflake" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/reassembly" -) - -// tcpVerdict is a subset of io.Verdict for TCP streams. -// We don't allow modifying or dropping a single packet -// for TCP streams for now, as it doesn't make much sense. -type tcpVerdict io.Verdict - -const ( - tcpVerdictAccept = tcpVerdict(io.VerdictAccept) - tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream) - tcpVerdictDropStream = tcpVerdict(io.VerdictDropStream) -) - -type tcpContext struct { - *gopacket.PacketMetadata - Verdict tcpVerdict - SrcMAC, DstMAC net.HardwareAddr -} - -func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo { - return ctx.CaptureInfo -} - -type tcpStreamFactory struct { - WorkerID int - Logger Logger - Node *snowflake.Node - - RulesetMutex sync.RWMutex - Ruleset ruleset.Ruleset - RulesetVersion uint64 -} - -func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { - id := f.Node.Generate() - ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) - ctx := ac.(*tcpContext) - info := ruleset.StreamInfo{ - ID: id.Int64(), - Protocol: ruleset.ProtocolTCP, - SrcMAC: append(net.HardwareAddr(nil), ctx.SrcMAC...), - DstMAC: append(net.HardwareAddr(nil), ctx.DstMAC...), - SrcIP: ipSrc, - DstIP: ipDst, - SrcPort: uint16(tcp.SrcPort), - DstPort: uint16(tcp.DstPort), - Props: make(analyzer.CombinedPropMap), - } - f.Logger.TCPStreamNew(f.WorkerID, info) - rs, version := f.currentRuleset() - var ans []analyzer.TCPAnalyzer - if rs != nil { - ans = analyzersToTCPAnalyzers(rs.Analyzers(info)) - } - // Create entries for each analyzer - entries := make([]*tcpStreamEntry, 0, 
len(ans)) - for _, a := range ans { - entries = append(entries, &tcpStreamEntry{ - Name: a.Name(), - Stream: a.NewTCP(analyzer.TCPInfo{ - SrcIP: ipSrc, - DstIP: ipDst, - SrcPort: uint16(tcp.SrcPort), - DstPort: uint16(tcp.DstPort), - }, &analyzerLogger{ - StreamID: id.Int64(), - Name: a.Name(), - Logger: f.Logger, - }), - HasLimit: a.Limit() > 0, - Quota: a.Limit(), - }) - } - return &tcpStream{ - info: info, - virgin: true, - logger: f.Logger, - rulesetVersion: version, - rulesetSource: f.currentRuleset, - activeEntries: entries, - } -} - -func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { - f.RulesetMutex.Lock() - defer f.RulesetMutex.Unlock() - f.Ruleset = r - f.RulesetVersion++ - return nil -} - -func (f *tcpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { - f.RulesetMutex.RLock() - defer f.RulesetMutex.RUnlock() - return f.Ruleset, f.RulesetVersion -} - -type tcpStream struct { - info ruleset.StreamInfo - virgin bool // true if no packets have been processed - logger Logger - rulesetVersion uint64 - rulesetSource func() (ruleset.Ruleset, uint64) - activeEntries []*tcpStreamEntry - doneEntries []*tcpStreamEntry - lastVerdict tcpVerdict -} - -type tcpStreamEntry struct { - Name string - Stream analyzer.TCPStream - HasLimit bool - Quota int -} - -func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { - if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { - // Make sure every stream matches against the ruleset at least once, - // even if there are no activeEntries, as the ruleset may have built-in - // properties that need to be matched. 
- return true - } else { - ctx := ac.(*tcpContext) - ctx.Verdict = s.lastVerdict - return false - } -} - -func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { - dir, start, end, skip := sg.Info() - rev := dir == reassembly.TCPDirServerToClient - avail, _ := sg.Lengths() - data := sg.Fetch(avail) - updated := false - for i := len(s.activeEntries) - 1; i >= 0; i-- { - // Important: reverse order so we can remove entries - entry := s.activeEntries[i] - update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data) - up1 := processPropUpdate(s.info.Props, entry.Name, update) - up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate) - updated = updated || up1 || up2 - if done { - s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...) - s.doneEntries = append(s.doneEntries, entry) - } - } - ctx := ac.(*tcpContext) - rs, version := s.currentRuleset() - rulesetChanged := version != s.rulesetVersion - s.rulesetVersion = version - if updated || s.virgin || rulesetChanged { - s.virgin = false - s.logger.TCPStreamPropUpdate(s.info, false) - // Match properties against ruleset - result := ruleset.MatchResult{Action: ruleset.ActionMaybe} - if rs != nil { - result = rs.Match(s.info) - } - action := result.Action - if action != ruleset.ActionMaybe && action != ruleset.ActionModify { - verdict := actionToTCPVerdict(action) - s.lastVerdict = verdict - ctx.Verdict = verdict - s.logger.TCPStreamAction(s.info, action, false) - // Verdict issued, no need to process any more packets - s.closeActiveEntries() - } - } - if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept { - // All entries are done but no verdict issued, accept stream - s.lastVerdict = tcpVerdictAcceptStream - ctx.Verdict = tcpVerdictAcceptStream - s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true) - } -} - -func (s *tcpStream) currentRuleset() (ruleset.Ruleset, uint64) { - if s.rulesetSource == nil { - return nil, 
s.rulesetVersion - } - return s.rulesetSource() -} - -func (s *tcpStream) rulesetChanged() bool { - _, version := s.currentRuleset() - return version != s.rulesetVersion -} - -func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { - s.closeActiveEntries() - return true -} - -func (s *tcpStream) closeActiveEntries() { - // Signal close to all active entries & move them to doneEntries - updated := false - for _, entry := range s.activeEntries { - update := entry.Stream.Close(false) - up := processPropUpdate(s.info.Props, entry.Name, update) - updated = updated || up - } - if updated { - s.logger.TCPStreamPropUpdate(s.info, true) - } - s.doneEntries = append(s.doneEntries, s.activeEntries...) - s.activeEntries = nil -} - -func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { - if !entry.HasLimit { - update, done = entry.Stream.Feed(rev, start, end, skip, data) - } else { - qData := data - if len(qData) > entry.Quota { - qData = qData[:entry.Quota] - } - update, done = entry.Stream.Feed(rev, start, end, skip, qData) - entry.Quota -= len(qData) - if entry.Quota <= 0 { - // Quota exhausted, signal close & move to doneEntries - closeUpdate = entry.Stream.Close(true) - done = true - } - } - return -} - -func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer { - tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans)) - for _, a := range ans { - if tcpM, ok := a.(analyzer.TCPAnalyzer); ok { - tcpAns = append(tcpAns, tcpM) - } - } - return tcpAns -} - -func actionToTCPVerdict(a ruleset.Action) tcpVerdict { - switch a { - case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify: - return tcpVerdictAcceptStream - case ruleset.ActionBlock, ruleset.ActionDrop: - return tcpVerdictDropStream - default: - // Should never happen - return tcpVerdictAcceptStream - } -} diff --git a/engine/tcp_flow.go 
b/engine/tcp_flow.go new file mode 100644 index 0000000..17020c3 --- /dev/null +++ b/engine/tcp_flow.go @@ -0,0 +1,302 @@ +package engine + +import ( + "net" + "sync" + + "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/io" + "git.difuse.io/Difuse/Mellaris/ruleset" + + "github.com/bwmarrin/snowflake" +) + +const tcpFlowMaxBuffer = 16384 + +type tcpFlowDirection uint8 + +const ( + tcpDirC2S tcpFlowDirection = iota + tcpDirS2C +) + +type tcpFlow struct { + streamID uint32 + srcIP [4]byte + dstIP [4]byte + srcPort uint16 + dstPort uint16 + + dirSeq [2]uint32 + dirBuf [2][]byte + + info ruleset.StreamInfo + virgin bool + logger Logger + rulesetVersion uint64 + rulesetSource func() (ruleset.Ruleset, uint64) + activeEntries []*tcpFlowEntry + doneEntries []*tcpFlowEntry + lastVerdict io.Verdict + feedCalled [2]bool +} + +type tcpFlowEntry struct { + Name string + Stream analyzer.TCPStream + HasLimit bool + Quota int +} + +func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { + if f.rulesetChanged() || f.virgin { + f.virgin = false + return io.VerdictAccept + } + if len(f.activeEntries) == 0 { + return f.lastVerdict + } + + dir, rev := f.resolveDirection(tcp) + + if tcp.RST || tcp.FIN { + f.closeActiveEntries() + f.maybeFinalizeVerdict() + return f.lastVerdict + } + + if len(payload) == 0 { + return io.VerdictAccept + } + + expected := f.dirSeq[dir] + if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected { + return io.VerdictAccept + } + f.feedCalled[dir] = true + f.dirBuf[dir] = append(f.dirBuf[dir], payload...) 
+ f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) + + if len(f.dirBuf[dir]) > tcpFlowMaxBuffer { + return io.VerdictAccept + } + + updated := false + for i := len(f.activeEntries) - 1; i >= 0; i-- { + entry := f.activeEntries[i] + update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir]) + u1 := processPropUpdate(f.info.Props, entry.Name, update) + u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate) + updated = updated || u1 || u2 + if done { + f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...) + f.doneEntries = append(f.doneEntries, entry) + } + } + + if updated { + f.logger.TCPStreamPropUpdate(f.info, false) + rs, version := f.currentRuleset() + f.rulesetVersion = version + result := ruleset.MatchResult{Action: ruleset.ActionMaybe} + if rs != nil { + result = rs.Match(f.info) + } + action := result.Action + if action != ruleset.ActionMaybe && action != ruleset.ActionModify { + verdict := actionToTCPVerdict(action) + f.lastVerdict = verdict + f.closeActiveEntries() + f.logger.TCPStreamAction(f.info, action, false) + return verdict + } + } + + f.maybeFinalizeVerdict() + return f.lastVerdict +} + +func (f *tcpFlow) maybeFinalizeVerdict() { + if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept { + f.lastVerdict = io.VerdictAcceptStream + f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true) + } +} + +func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) { + if tcp.SrcPort == f.srcPort { + return uint8(tcpDirC2S), false + } + return uint8(tcpDirS2C), true +} + +func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) { + if f.rulesetSource == nil { + return nil, f.rulesetVersion + } + return f.rulesetSource() +} + +func (f *tcpFlow) rulesetChanged() bool { + _, version := f.currentRuleset() + return version != f.rulesetVersion +} + +func (f *tcpFlow) closeActiveEntries() { + updated := false + for _, entry := range f.activeEntries { + update := entry.Stream.Close(false) + updated = 
updated || processPropUpdate(f.info.Props, entry.Name, update) + } + if updated { + f.logger.TCPStreamPropUpdate(f.info, true) + } + f.doneEntries = append(f.doneEntries, f.activeEntries...) + f.activeEntries = nil +} + +type tcpFlowManager struct { + mu sync.Mutex + flows map[uint32]*tcpFlow + sfNode *snowflake.Node + logger Logger + rulesetSource func() (ruleset.Ruleset, uint64) + workerID int + macResolver *sourceMACResolver +} + +func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node) *tcpFlowManager { + return &tcpFlowManager{ + flows: make(map[uint32]*tcpFlow), + sfNode: node, + logger: logger, + workerID: workerID, + macResolver: macResolver, + } +} + +func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict { + m.mu.Lock() + flow, ok := m.flows[streamID] + if !ok { + flow = m.createFlow(streamID, l3, tcp, srcMAC, dstMAC) + m.flows[streamID] = flow + } + m.mu.Unlock() + + verdict := flow.feed(l3, tcp, payload) + + if verdict == io.VerdictAcceptStream || verdict == io.VerdictDropStream || tcp.RST || tcp.FIN { + m.mu.Lock() + delete(m.flows, streamID) + m.mu.Unlock() + } + + return verdict +} + +func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { + id := m.sfNode.Generate() + ipSrc := net.IP(l3.SrcIP[:]) + ipDst := net.IP(l3.DstIP[:]) + if len(srcMAC) == 0 && m.macResolver != nil { + srcMAC = m.macResolver.Resolve(ipSrc) + } + info := ruleset.StreamInfo{ + ID: id.Int64(), + Protocol: ruleset.ProtocolTCP, + SrcMAC: append(net.HardwareAddr(nil), srcMAC...), + DstMAC: append(net.HardwareAddr(nil), dstMAC...), + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: tcp.SrcPort, + DstPort: tcp.DstPort, + Props: make(analyzer.CombinedPropMap), + } + m.logger.TCPStreamNew(m.workerID, info) + rs, version := m.rulesetSource() + var ans []analyzer.TCPAnalyzer + if rs != nil { + ans = 
analyzersToTCPAnalyzers(rs.Analyzers(info)) + } + entries := make([]*tcpFlowEntry, 0, len(ans)) + for _, a := range ans { + entries = append(entries, &tcpFlowEntry{ + Name: a.Name(), + Stream: a.NewTCP(analyzer.TCPInfo{ + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: tcp.SrcPort, + DstPort: tcp.DstPort, + }, &analyzerLogger{ + StreamID: id.Int64(), + Name: a.Name(), + Logger: m.logger, + }), + HasLimit: a.Limit() > 0, + Quota: a.Limit(), + }) + } + + flow := &tcpFlow{ + streamID: streamID, + srcIP: l3.SrcIP, + dstIP: l3.DstIP, + srcPort: tcp.SrcPort, + dstPort: tcp.DstPort, + info: info, + virgin: true, + logger: m.logger, + rulesetSource: m.rulesetSource, + rulesetVersion: version, + activeEntries: entries, + lastVerdict: io.VerdictAccept, + } + flow.dirSeq[tcpDirC2S] = tcp.Seq + 1 + return flow +} + +func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) { + m.rulesetSource = func() (ruleset.Ruleset, uint64) { + return r, version + } +} + +func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { + if !entry.HasLimit { + update, done = entry.Stream.Feed(rev, true, false, 0, data) + } else { + qData := data + if len(qData) > entry.Quota { + qData = qData[:entry.Quota] + } + update, done = entry.Stream.Feed(rev, true, false, 0, qData) + entry.Quota -= len(qData) + if entry.Quota <= 0 { + closeUpdate = entry.Stream.Close(true) + done = true + } + } + return +} + +func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer { + tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans)) + for _, a := range ans { + if ta, ok := a.(analyzer.TCPAnalyzer); ok { + tcpAns = append(tcpAns, ta) + } + } + return tcpAns +} + +func actionToTCPVerdict(a ruleset.Action) io.Verdict { + switch a { + case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify: + return io.VerdictAcceptStream + case ruleset.ActionBlock, ruleset.ActionDrop: + return io.VerdictDropStream + 
default: + return io.VerdictAcceptStream + } +} diff --git a/engine/worker.go b/engine/worker.go index 558ba72..ade57a7 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -10,20 +10,13 @@ import ( "github.com/bwmarrin/snowflake" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/google/gopacket/reassembly" ) -const ( - defaultChanSize = 64 - defaultTCPMaxBufferedPagesTotal = 4096 - defaultTCPMaxBufferedPagesPerConnection = 64 - defaultUDPMaxStreams = 4096 -) +var _ Engine = (*engine)(nil) type workerPacket struct { StreamID uint32 Data []byte - LayerType gopacket.LayerType SrcMAC net.HardwareAddr DstMAC net.HardwareAddr SetVerdict func(io.Verdict, []byte) error @@ -35,12 +28,8 @@ type worker struct { logger Logger macResolver *sourceMACResolver - tcpStreamFactory *tcpStreamFactory - tcpStreamPool *reassembly.StreamPool - tcpAssembler *reassembly.Assembler - - udpStreamFactory *udpStreamFactory - udpStreamManager *udpStreamManager + tcpFlowMgr *tcpFlowManager + udpSM *udpStreamManager modSerializeBuffer gopacket.SerializeBuffer } @@ -51,23 +40,17 @@ type workerConfig struct { Logger Logger Ruleset ruleset.Ruleset MACResolver *sourceMACResolver - TCPMaxBufferedPagesTotal int - TCPMaxBufferedPagesPerConn int + TCPMaxBufferedPagesTotal int // unused, kept for config compat + TCPMaxBufferedPagesPerConn int // unused, kept for config compat UDPMaxStreams int } func (c *workerConfig) fillDefaults() { if c.ChanSize <= 0 { - c.ChanSize = defaultChanSize - } - if c.TCPMaxBufferedPagesTotal <= 0 { - c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal - } - if c.TCPMaxBufferedPagesPerConn <= 0 { - c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection + c.ChanSize = 64 } if c.UDPMaxStreams <= 0 { - c.UDPMaxStreams = defaultUDPMaxStreams + c.UDPMaxStreams = 4096 } } @@ -77,16 +60,12 @@ func newWorker(config workerConfig) (*worker, error) { if err != nil { return nil, err } - tcpSF := &tcpStreamFactory{ - WorkerID: 
config.ID, - Logger: config.Logger, - Node: sfNode, - Ruleset: config.Ruleset, + + tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode) + if config.Ruleset != nil { + tcpMgr.updateRuleset(config.Ruleset, 0) } - tcpStreamPool := reassembly.NewStreamPool(tcpSF) - tcpAssembler := reassembly.NewAssembler(tcpStreamPool) - tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal - tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn + udpSF := &udpStreamFactory{ WorkerID: config.ID, Logger: config.Logger, @@ -97,25 +76,24 @@ func newWorker(config workerConfig) (*worker, error) { if err != nil { return nil, err } + return &worker{ id: config.ID, packetChan: make(chan *workerPacket, config.ChanSize), logger: config.Logger, macResolver: config.MACResolver, - tcpStreamFactory: tcpSF, - tcpStreamPool: tcpStreamPool, - tcpAssembler: tcpAssembler, - udpStreamFactory: udpSF, - udpStreamManager: udpSM, + tcpFlowMgr: tcpMgr, + udpSM: udpSM, modSerializeBuffer: gopacket.NewSerializeBuffer(), }, nil } -func (w *worker) Feed(p *workerPacket) { +func (w *worker) Feed(p *workerPacket) bool { select { case w.packetChan <- p: + return true default: - _ = p.SetVerdict(io.VerdictAccept, nil) + return false } } @@ -126,78 +104,116 @@ func (w *worker) Run(ctx context.Context) { select { case <-ctx.Done(): return - case wPkt := <-w.packetChan: - if wPkt == nil { + case wp := <-w.packetChan: + if wp == nil { return } - pkt := gopacket.NewPacket(wPkt.Data, wPkt.LayerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) - v, b := w.handle(wPkt.StreamID, pkt, wPkt.SrcMAC, wPkt.DstMAC) - _ = wPkt.SetVerdict(v, b) + v, b := w.handle(wp) + _ = wp.SetVerdict(v, b) } } } func (w *worker) UpdateRuleset(r ruleset.Ruleset) error { - if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil { - return err - } - return w.udpStreamFactory.UpdateRuleset(r) + w.tcpFlowMgr.updateRuleset(r, 0) + return w.udpSM.factory.UpdateRuleset(r) } -func 
(w *worker) handle(streamID uint32, p gopacket.Packet, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { - netLayer, trLayer := p.NetworkLayer(), p.TransportLayer() - if netLayer == nil || trLayer == nil { - // Invalid packet +func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) { + data := wp.Data + if len(data) == 0 { return io.VerdictAccept, nil } - ipFlow := netLayer.NetworkFlow() - if len(srcMAC) == 0 && w.macResolver != nil { - srcMAC = w.macResolver.Resolve(net.IP(ipFlow.Src().Raw())) - } - switch tr := trLayer.(type) { - case *layers.TCP: - return w.handleTCP(ipFlow, srcMAC, dstMAC, p.Metadata(), tr), nil - case *layers.UDP: - v, modPayload := w.handleUDP(streamID, ipFlow, srcMAC, dstMAC, tr) - if v == io.VerdictAcceptModify && modPayload != nil { - tr.Payload = modPayload - _ = tr.SetNetworkLayerForChecksum(netLayer) - _ = w.modSerializeBuffer.Clear() - err := gopacket.SerializePacket(w.modSerializeBuffer, - gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - }, p) - if err != nil { - // Just accept without modification for now + + ipVersion := data[0] >> 4 + if ipVersion == 4 { + l3, transport, ok := ParseL3(data) + if !ok { + return io.VerdictAccept, nil + } + switch l3.Protocol { + case 6: // TCP + tcp, payload, ok := ParseTCP(transport) + if !ok { return io.VerdictAccept, nil } - return v, w.modSerializeBuffer.Bytes() + verdict := w.tcpFlowMgr.handle( + wp.StreamID, l3, tcp, payload, + wp.SrcMAC, wp.DstMAC, + ) + return verdict, nil + + case 17: // UDP + udp, payload, ok := ParseUDP(transport) + if !ok { + return io.VerdictAccept, nil + } + v, modPayload := w.handleUDP( + wp.StreamID, l3, udp, payload, + wp.SrcMAC, wp.DstMAC, + ) + if v == io.VerdictAcceptModify && modPayload != nil { + return w.serializeModifiedUDP(data, l3, udp, transport, modPayload) + } + return v, nil + + default: + return io.VerdictAccept, nil } - return v, nil - default: - // Unsupported protocol + } + + // Ethernet frame path (for custom 
PacketIO) + if ipVersion == 6 { + // TODO: IPv6 support with raw parsing return io.VerdictAccept, nil } + + return io.VerdictAccept, nil } -func (w *worker) handleTCP(ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict { - ctx := &tcpContext{ - PacketMetadata: pMeta, - Verdict: tcpVerdictAccept, - SrcMAC: srcMAC, - DstMAC: dstMAC, +func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { + ipSrc := net.IP(l3.SrcIP[:]) + ipDst := net.IP(l3.DstIP[:]) + ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4()) + udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)}) + + if len(srcMAC) == 0 && w.macResolver != nil { + srcMAC = w.macResolver.Resolve(ipSrc) } - w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx) - return io.Verdict(ctx.Verdict) -} -func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, udp *layers.UDP) (io.Verdict, []byte) { - ctx := &udpContext{ + uc := &udpContext{ Verdict: udpVerdictAccept, SrcMAC: srcMAC, DstMAC: dstMAC, } - w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx) - return io.Verdict(ctx.Verdict), ctx.Packet + // Temporarily set payload on a UDP layer so existing UDP handling works + // We pass the payload through the context + w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ + BaseLayer: layers.BaseLayer{Payload: payload}, + SrcPort: layers.UDPPort(udp.SrcPort), + DstPort: layers.UDPPort(udp.DstPort), + }, uc) + return io.Verdict(uc.Verdict), uc.Packet +} + +func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) { + ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) + netLayer := 
ipPkt.NetworkLayer() + trLayer := ipPkt.TransportLayer() + if netLayer == nil || trLayer == nil { + return io.VerdictAccept, nil + } + udpLayer, ok := trLayer.(*layers.UDP) + if !ok { + return io.VerdictAccept, nil + } + udpLayer.Payload = modPayload + _ = udpLayer.SetNetworkLayerForChecksum(netLayer) + _ = w.modSerializeBuffer.Clear() + err := gopacket.SerializePacket(w.modSerializeBuffer, + gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt) + if err != nil { + return io.VerdictAccept, nil + } + return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes() } diff --git a/io/nfqueue.go b/io/nfqueue.go index b4956b4..3fb0247 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -18,9 +18,9 @@ import ( ) const ( - nfqueueNum = 100 - nfqueueMaxPacketLen = 0xFFFF + nfqueueNumStart = 100 nfqueueDefaultQueueSize = 128 + nfqueueDefaultMaxLen = 0xFFFF nfqueueConnMarkAccept = 1001 nfqueueConnMarkDrop = 1002 @@ -29,17 +29,25 @@ const ( nftTable = "mellaris" ) -func generateNftRules(local, rst bool) (*nftTableSpec, error) { +func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) { if local && rst { return nil, errors.New("tcp rst is not supported in local mode") } + if numQueues < 1 { + numQueues = 1 + } table := &nftTableSpec{ Family: nftFamily, Table: nftTable, } table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept)) table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop)) - table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum)) + queueEnd := nfqueueNumStart + numQueues - 1 + if numQueues == 1 { + table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart)) + } else { + table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, queueEnd)) + } if local { table.Chains = []nftChainSpec{ {Chain: "INPUT", Header: "type filter hook input priority filter; policy 
accept;"}, @@ -52,7 +60,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) { } for i := range table.Chains { c := &table.Chains[i] - c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections + c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") if rst { c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") @@ -63,10 +71,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) { return table, nil } -func generateIptRules(local, rst bool) ([]iptRule, error) { +func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) { if local && rst { return nil, errors.New("tcp rst is not supported in local mode") } + if numQueues < 1 { + numQueues = 1 + } var chains []string if local { chains = []string{"INPUT", "OUTPUT"} @@ -75,16 +86,19 @@ func generateIptRules(local, rst bool) ([]iptRule, error) { } rules := make([]iptRule, 0, 4*len(chains)) for _, chain := range chains { - // Bypass protected connections rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}}) rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) if rst { rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) } rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}) - rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}) + if numQueues == 1 { + rules = append(rules, iptRule{"filter", 
chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}}) + } else { + queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1) + rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}}) + } } - return rules, nil } @@ -93,12 +107,12 @@ var _ PacketIO = (*nfqueuePacketIO)(nil) var errNotNFQueuePacket = errors.New("not an NFQueue packet") type nfqueuePacketIO struct { - n *nfqueue.Nfqueue - local bool - rst bool - rSet bool // whether the nftables/iptables rules have been set + nqs []*nfqueue.Nfqueue + numQueues int + local bool + rst bool + rSet bool - // iptables not nil = use iptables instead of nftables ipt4 *iptables.IPTables ipt6 *iptables.IPTables @@ -106,21 +120,28 @@ type nfqueuePacketIO struct { } type NFQueuePacketIOConfig struct { - QueueSize uint32 - ReadBuffer int - WriteBuffer int - Local bool - RST bool + QueueSize uint32 + ReadBuffer int + WriteBuffer int + Local bool + RST bool + NumQueues int + MaxPacketLen uint32 } func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { if config.QueueSize == 0 { config.QueueSize = nfqueueDefaultQueueSize } + if config.NumQueues <= 0 { + config.NumQueues = 1 + } + if config.MaxPacketLen == 0 { + config.MaxPacketLen = nfqueueDefaultMaxLen + } var ipt4, ipt6 *iptables.IPTables var err error if nftCheck() != nil { - // We prefer nftables, but if it's not available, fall back to iptables ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, err @@ -130,36 +151,50 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { return nil, err } } - n, err := nfqueue.Open(&nfqueue.Config{ - NfQueue: nfqueueNum, - MaxPacketLen: nfqueueMaxPacketLen, - MaxQueueLen: config.QueueSize, - Copymode: nfqueue.NfQnlCopyPacket, - Flags: nfqueue.NfQaCfgFlagConntrack, - }) - if err != nil { - return nil, err - } - if config.ReadBuffer > 0 { - 
err = n.Con.SetReadBuffer(config.ReadBuffer) + + nqs := make([]*nfqueue.Nfqueue, config.NumQueues) + for i := range nqs { + n, err := nfqueue.Open(&nfqueue.Config{ + NfQueue: uint16(nfqueueNumStart + i), + MaxPacketLen: config.MaxPacketLen, + MaxQueueLen: config.QueueSize, + Copymode: nfqueue.NfQnlCopyPacket, + Flags: nfqueue.NfQaCfgFlagConntrack, + }) if err != nil { - _ = n.Close() + for j := 0; j < i; j++ { + nqs[j].Close() + } return nil, err } - } - if config.WriteBuffer > 0 { - err = n.Con.SetWriteBuffer(config.WriteBuffer) - if err != nil { - _ = n.Close() - return nil, err + if config.ReadBuffer > 0 { + err = n.Con.SetReadBuffer(config.ReadBuffer) + if err != nil { + for j := 0; j <= i; j++ { + nqs[j].Close() + } + return nil, err + } } + if config.WriteBuffer > 0 { + err = n.Con.SetWriteBuffer(config.WriteBuffer) + if err != nil { + for j := 0; j <= i; j++ { + nqs[j].Close() + } + return nil, err + } + } + nqs[i] = n } + return &nfqueuePacketIO{ - n: n, - local: config.Local, - rst: config.RST, - ipt4: ipt4, - ipt6: ipt6, + nqs: nqs, + numQueues: config.NumQueues, + local: config.Local, + rst: config.RST, + ipt4: ipt4, + ipt6: ipt6, protectedDialer: &net.Dialer{ Control: func(network, address string, c syscall.RawConn) error { var err error @@ -175,60 +210,63 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { }, nil } -func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { - err := n.n.RegisterWithErrorFunc(ctx, - func(a nfqueue.Attribute) int { - if ok, verdict := n.packetAttributeSanityCheck(a); !ok { - if a.PacketID != nil { - _ = n.n.SetVerdict(*a.PacketID, verdict) - } - return 0 - } - p := &nfqueuePacket{ - id: *a.PacketID, - streamID: ctIDFromCtBytes(*a.Ct), - data: *a.Payload, - } - return okBoolToInt(cb(p, nil)) - }, - func(e error) int { - if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) { - if errors.Is(opErr.Err, unix.ENOBUFS) { - // Kernel buffer temporarily full, ignore +func (nio 
*nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { + for i, nq := range nio.nqs { + nq := nq + err := nq.RegisterWithErrorFunc(ctx, + func(a nfqueue.Attribute) int { + if ok, verdict := nio.packetAttributeSanityCheck(a); !ok { + if a.PacketID != nil { + _ = nq.SetVerdict(*a.PacketID, verdict) + } return 0 } - } - return okBoolToInt(cb(nil, e)) - }) - if err != nil { - return err - } - if !n.rSet { - if n.ipt4 != nil { - err = n.setupIpt(n.local, n.rst, false) - } else { - err = n.setupNft(n.local, n.rst, false) - } + p := &nfqueuePacket{ + id: *a.PacketID, + streamID: ctIDFromCtBytes(*a.Ct), + data: *a.Payload, + nq: nq, + } + return okBoolToInt(cb(p, nil)) + }, + func(e error) int { + if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) { + if errors.Is(opErr.Err, unix.ENOBUFS) { + return 0 + } + } + return okBoolToInt(cb(nil, e)) + }) if err != nil { return err } - n.rSet = true + } + if !nio.rSet { + if nio.ipt4 != nil { + err := nio.setupIpt(nio.local, nio.rst, false) + if err != nil { + return err + } + } else { + err := nio.setupNft(nio.local, nio.rst, false) + if err != nil { + return err + } + } + nio.rSet = true } return nil } -func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) { +func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) { if a.PacketID == nil { - // Re-inject to NFQUEUE is actually not possible in this condition return false, -1 } if a.Payload == nil || len(*a.Payload) < 20 { - // 20 is the minimum possible size of an IP packet return false, nfqueue.NfDrop } if a.Ct == nil { - // Multicast packets may not have a conntrack, but only appear in local mode - if n.local { + if nio.local { return false, nfqueue.NfAccept } return false, nfqueue.NfDrop @@ -236,46 +274,54 @@ func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bo return true, -1 } -func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, 
newPacket []byte) error { +func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error { nP, ok := p.(*nfqueuePacket) if !ok { return &ErrInvalidPacket{Err: errNotNFQueuePacket} } switch v { case VerdictAccept: - return n.n.SetVerdict(nP.id, nfqueue.NfAccept) + return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept) case VerdictAcceptModify: - return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket) + return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket) case VerdictAcceptStream: - return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept) + return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept) case VerdictDrop: - return n.n.SetVerdict(nP.id, nfqueue.NfDrop) + return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop) case VerdictDropStream: - return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop) + return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop) default: - // Invalid verdict, ignore for now return nil } } -func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) { - return n.protectedDialer.DialContext(ctx, network, address) +func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) { + return nio.protectedDialer.DialContext(ctx, network, address) } -func (n *nfqueuePacketIO) Close() error { - if n.rSet { - if n.ipt4 != nil { - _ = n.setupIpt(n.local, n.rst, true) +func (nio *nfqueuePacketIO) Close() error { + if nio.rSet { + if nio.ipt4 != nil { + _ = nio.setupIpt(nio.local, nio.rst, true) } else { - _ = n.setupNft(n.local, n.rst, true) + _ = nio.setupNft(nio.local, nio.rst, true) } - n.rSet = false + nio.rSet = false } - return n.n.Close() + var errs []error + for _, nq := range nio.nqs { + if err := nq.Close(); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return errs[0] + } + return nil } 
-func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { - rules, err := generateNftRules(local, rst) +func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error { + rules, err := generateNftRules(local, rst, nio.numQueues) if err != nil { return err } @@ -283,30 +329,23 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { if remove { err = nftDelete(nftFamily, nftTable) } else { - // Delete first to make sure no leftover rules _ = nftDelete(nftFamily, nftTable) err = nftAdd(rulesText) } - if err != nil { - return err - } - return nil + return err } -func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error { - rules, err := generateIptRules(local, rst) +func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error { + rules, err := generateIptRules(local, rst, nio.numQueues) if err != nil { return err } if remove { - err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) + err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules) } else { - err = iptsBatchAppendUnique([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) + err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules) } - if err != nil { - return err - } - return nil + return err } var _ Packet = (*nfqueuePacket)(nil) @@ -315,30 +354,22 @@ type nfqueuePacket struct { id uint32 streamID uint32 data []byte + nq *nfqueue.Nfqueue } -func (p *nfqueuePacket) StreamID() uint32 { - return p.streamID -} - -func (p *nfqueuePacket) Data() []byte { - return p.data -} +func (p *nfqueuePacket) StreamID() uint32 { return p.streamID } +func (p *nfqueuePacket) Data() []byte { return p.data } func okBoolToInt(ok bool) int { if ok { return 0 - } else { - return 1 } + return 1 } func nftCheck() error { _, err := exec.LookPath("nft") - if err != nil { - return err - } - return nil + return err } func nftAdd(input string) error { @@ -363,7 +394,6 @@ func (t *nftTableSpec) String() string { for _, c := range t.Chains { 
chains = append(chains, c.String()) } - return fmt.Sprintf(` %s diff --git a/ruleset/expr.go b/ruleset/expr.go index d4bceaa..5c76c3d 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -7,6 +7,7 @@ import ( "os" "reflect" "strings" + "sync" "time" "github.com/expr-lang/expr/builtin" @@ -67,6 +68,19 @@ type compiledExprRule struct { var _ Ruleset = (*exprRuleset)(nil) +var ( + envPool = sync.Pool{ + New: func() any { + return make(map[string]any, 16) + }, + } + subMapPool = sync.Pool{ + New: func() any { + return make(map[string]any, 8) + }, + } +) + type exprRuleset struct { Rules []compiledExprRule Ans []analyzer.Analyzer @@ -79,7 +93,9 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { } func (r *exprRuleset) Match(info StreamInfo) MatchResult { - env := streamInfoToExprEnv(info) + env := envPool.Get().(map[string]any) + clear(env) + populateExprEnv(env, info) now := time.Now() for _, rule := range r.Rules { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { @@ -99,6 +115,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { r.Logger.Log(logInfo, rule.Name) } if rule.Action != nil { + envPool.Put(env) return MatchResult{ Action: *rule.Action, ModInstance: rule.ModInstance, @@ -106,7 +123,7 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { } } } - // No match + envPool.Put(env) return MatchResult{ Action: ActionMaybe, } @@ -228,30 +245,26 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier }, nil } -func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { - m := map[string]interface{}{ - "id": info.ID, - "proto": info.Protocol.String(), - "mac": map[string]string{ - "src": info.SrcMAC.String(), - "dst": info.DstMAC.String(), - }, - "ip": map[string]string{ - "src": info.SrcIP.String(), - "dst": info.DstIP.String(), - }, - "port": map[string]uint16{ - "src": info.SrcPort, - "dst": info.DstPort, - }, +func populateExprEnv(m 
map[string]any, info StreamInfo) { + m["id"] = info.ID + m["proto"] = info.Protocol.String() + m["mac"] = map[string]string{ + "src": info.SrcMAC.String(), + "dst": info.DstMAC.String(), + } + m["ip"] = map[string]string{ + "src": info.SrcIP.String(), + "dst": info.DstIP.String(), + } + m["port"] = map[string]uint16{ + "src": info.SrcPort, + "dst": info.DstPort, } for anName, anProps := range info.Props { if len(anProps) != 0 { - // Ignore analyzers with empty properties m[anName] = anProps } } - return m } func isBuiltInAnalyzer(name string) bool {