Files
Mellaris/engine/tcp_flow.go
T
hayzam 7a3f6e945d Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
2026-05-13 06:10:38 +05:30

307 lines
7.7 KiB
Go

package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
)
// tcpFlowMaxBuffer is the per-direction buffered payload size (16 KiB) beyond
// which a flow stops feeding that direction to its analyzers.
const tcpFlowMaxBuffer = 16 << 10

// tcpFlowDirection indexes the per-direction state of a tcpFlow.
type tcpFlowDirection uint8

// Directions of travel, relative to the packet that created the flow.
const (
	// tcpDirC2S is client to server (same source port as the creating packet).
	tcpDirC2S tcpFlowDirection = 0
	// tcpDirS2C is server to client.
	tcpDirS2C tcpFlowDirection = 1
)
// tcpFlow is the per-stream inspection state of one TCP connection: the
// payload buffered per direction, the analyzers still consuming it, and the
// last verdict computed for the stream.
type tcpFlow struct {
	// streamID is the key under which tcpFlowManager stores this flow.
	streamID uint32
	// Addresses and ports of the packet that created the flow. Of these,
	// only srcPort is read in this file (by resolveDirection).
	srcIP [4]byte
	dstIP [4]byte
	srcPort uint16
	dstPort uint16
	// dirSeq holds, per direction, the sequence number at which the next
	// in-order payload segment is expected to start.
	dirSeq [2]uint32
	// dirBuf accumulates the accepted payload bytes of each direction; this
	// buffer is what the analyzers are fed.
	dirBuf [2][]byte
	// info describes the stream (including analyzer-produced Props) and is
	// what the ruleset is matched against.
	info ruleset.StreamInfo
	// virgin is true until the flow has been run through runMatch once.
	virgin bool
	logger Logger
	// rulesetVersion is the ruleset version last used for matching; a
	// different version reported by rulesetSource triggers a re-match.
	rulesetVersion uint64
	rulesetSource func() (ruleset.Ruleset, uint64)
	// activeEntries are analyzers still being fed; doneEntries are those that
	// finished, exhausted their quota, or were closed.
	activeEntries []*tcpFlowEntry
	doneEntries []*tcpFlowEntry
	// lastVerdict is the verdict returned for the flow's packets.
	lastVerdict io.Verdict
	// feedCalled records, per direction, whether a payload segment has been
	// accepted yet; the first one bypasses the sequence check in feed.
	feedCalled [2]bool
}
// tcpFlowEntry tracks one analyzer attached to a tcpFlow.
type tcpFlowEntry struct {
	// Name is the analyzer's name; it is passed to processPropUpdate together
	// with the flow's Props when the analyzer reports updates.
	Name string
	// Stream is the analyzer's per-flow stream instance.
	Stream analyzer.TCPStream
	// HasLimit reports whether the analyzer declared a positive byte limit;
	// when it did, Quota is the number of bytes it may still be fed.
	HasLimit bool
	Quota int
}
// feed processes one TCP segment of the flow and returns the verdict to apply.
//
// A flow that has already been matched, whose ruleset version is unchanged and
// that has no analyzers left short-circuits to its cached verdict. RST/FIN
// segments close every active analyzer before matching. Otherwise an in-order
// payload is appended to its direction's buffer and the analyzers are fed.
// The ruleset is (re)matched when the flow is new, when the ruleset version
// changed, or when the analyzers changed the flow's properties.
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
	rs, version := f.currentRuleset()
	rulesetChanged := version != f.rulesetVersion
	if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
		return f.lastVerdict
	}
	if tcp.RST || tcp.FIN {
		f.closeActiveEntries()
		f.runMatch(rs, version, rulesetChanged)
		f.maybeFinalizeVerdict()
		return f.lastVerdict
	}
	propsUpdated := false
	if len(payload) > 0 {
		dir, rev := f.resolveDirection(tcp)
		expected := f.dirSeq[dir]
		// The first payload of a direction (or one with no recorded
		// expectation) is always accepted; after that only a segment that
		// starts exactly where the previous one ended is. Retransmitted and
		// out-of-order segments are ignored, not reassembled.
		if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
			f.feedCalled[dir] = true
			f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
			// A direction whose buffer has outgrown tcpFlowMaxBuffer is never
			// fed again, so stop accumulating its bytes: appending past that
			// point would only grow memory for the rest of the flow's life.
			if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
				f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
				if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
					propsUpdated = f.feedAnalyzers(rev)
				}
			}
		}
	}
	// Re-match whenever the analyzers changed the properties the ruleset sees;
	// otherwise property-based rules could only take effect on the first match.
	f.runMatch(rs, version, rulesetChanged || propsUpdated)
	f.maybeFinalizeVerdict()
	return f.lastVerdict
}

// feedAnalyzers hands the buffered bytes of one direction to every active
// analyzer entry and reports whether any of them changed f.info.Props.
//
// rev selects the server-to-client buffer (true) or the client-to-server
// buffer (false). Entries that report completion (or exhaust their quota) are
// moved from activeEntries to doneEntries. Entries are visited back to front
// so removing one does not shift the indices still to be visited.
//
// NOTE(review): each call passes the whole accumulated buffer of the
// direction, not only the newest segment. That assumes analyzers re-parse from
// the start rather than appending to internal state, and it means limited
// entries are charged again for bytes they already saw — confirm against the
// analyzer implementations.
func (f *tcpFlow) feedAnalyzers(rev bool) bool {
	buf := f.dirBuf[tcpDirC2S]
	if rev {
		buf = f.dirBuf[tcpDirS2C]
	}
	updated := false
	for i := len(f.activeEntries) - 1; i >= 0; i-- {
		entry := f.activeEntries[i]
		update, closeUpdate, done := feedFlowEntry(entry, rev, buf)
		// Both updates must be merged, so evaluate both calls unconditionally.
		u1 := processPropUpdate(f.info.Props, entry.Name, update)
		u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
		if u1 || u2 {
			updated = true
			f.logger.TCPStreamPropUpdate(f.info, false)
		}
		if done {
			f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
			f.doneEntries = append(f.doneEntries, entry)
		}
	}
	return updated
}
// runMatch evaluates the ruleset against the flow's StreamInfo. It only does
// so when the flow has never been matched or the caller requests a re-match
// via rulesetChanged; it then records version as the ruleset version in use.
//
// A nil ruleset counts as ActionMaybe. A definitive action (anything other
// than Maybe or Modify) becomes lastVerdict via actionToTCPVerdict, closes
// every active analyzer, and is logged.
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool) {
	if !(f.virgin || rulesetChanged) {
		return
	}
	f.virgin, f.rulesetVersion = false, version

	res := ruleset.MatchResult{Action: ruleset.ActionMaybe}
	if rs != nil {
		res = rs.Match(f.info)
	}
	switch res.Action {
	case ruleset.ActionMaybe, ruleset.ActionModify:
		// Undecided (Modify has no TCP handling here): keep inspecting.
		return
	}
	f.lastVerdict = actionToTCPVerdict(res.Action)
	f.closeActiveEntries()
	f.logger.TCPStreamAction(f.info, res.Action, false)
}
// maybeFinalizeVerdict upgrades a plain Accept verdict to AcceptStream once no
// analyzer is left to feed, logging it as an implicit Allow. AcceptStream is
// one of the verdicts on which tcpFlowManager.handle forgets the flow.
func (f *tcpFlow) maybeFinalizeVerdict() {
	if len(f.activeEntries) > 0 || f.lastVerdict != io.VerdictAccept {
		return
	}
	f.lastVerdict = io.VerdictAcceptStream
	f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
}
// resolveDirection classifies a segment as client-to-server when its source
// port equals the port of the packet that created the flow, and as
// server-to-client otherwise; rev is true for server-to-client.
//
// Only source ports are compared, so in a flow whose two endpoints use the
// same port every segment is classified as client-to-server.
func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) {
	if tcp.SrcPort != f.srcPort {
		return uint8(tcpDirS2C), true
	}
	return uint8(tcpDirC2S), false
}
// currentRuleset returns the ruleset and version reported by rulesetSource.
// Without a source it reports no ruleset and the version already recorded,
// so callers observe no version change.
func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
	if src := f.rulesetSource; src != nil {
		return src()
	}
	return nil, f.rulesetVersion
}
// closeActiveEntries calls Close(false) on every still-active analyzer entry,
// merges each returned property update into f.info.Props, and moves all of
// the entries to doneEntries. A single TCPStreamPropUpdate is logged if any
// property changed.
func (f *tcpFlow) closeActiveEntries() {
	updated := false
	for _, entry := range f.activeEntries {
		update := entry.Stream.Close(false)
		// Do not fold this into `updated = updated || processPropUpdate(...)`:
		// that short-circuits once one entry has changed Props, and the close
		// updates of every later entry would never be merged.
		if processPropUpdate(f.info.Props, entry.Name, update) {
			updated = true
		}
	}
	if updated {
		f.logger.TCPStreamPropUpdate(f.info, true)
	}
	f.doneEntries = append(f.doneEntries, f.activeEntries...)
	f.activeEntries = nil
}
// tcpFlowManager tracks in-progress TCP flows, keyed by stream ID, and creates
// their inspection state when a stream is first seen.
type tcpFlowManager struct {
	// mu is held by handle around every access to flows.
	mu sync.Mutex
	flows map[uint32]*tcpFlow
	// sfNode generates the snowflake IDs assigned to new streams.
	sfNode *snowflake.Node
	logger Logger
	// rulesetSource returns the current ruleset and its version; it is nil
	// until updateRuleset installs one.
	rulesetSource func() (ruleset.Ruleset, uint64)
	// workerID is reported to Logger.TCPStreamNew for new streams.
	workerID int
	// macResolver (optional) looks up the source MAC of a new stream when the
	// packet did not carry one.
	macResolver *sourceMACResolver
	// selector narrows the ruleset's analyzers for a new stream based on the
	// payload of the packet that created it.
	selector *analyzerSelector
}
// newTCPFlowManager returns a manager with an empty flow table. Its ruleset
// source is left unset; updateRuleset installs one.
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
	m := &tcpFlowManager{
		workerID:    workerID,
		logger:      logger,
		macResolver: macResolver,
		sfNode:      node,
		selector:    selector,
	}
	m.flows = make(map[uint32]*tcpFlow)
	return m
}
// handle feeds one TCP packet to the flow for streamID and returns the flow's
// verdict. If the stream has not been seen before, its flow is created first.
//
// The flow is forgotten once its verdict is final for the stream
// (AcceptStream or DropStream) or the packet carries RST or FIN. m.mu is held
// only around flow-table access; flow.feed runs without it.
func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict {
	m.mu.Lock()
	flow, found := m.flows[streamID]
	if !found {
		flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
		m.flows[streamID] = flow
	}
	m.mu.Unlock()

	verdict := flow.feed(l3, tcp, payload)

	finished := tcp.RST || tcp.FIN
	switch verdict {
	case io.VerdictAcceptStream, io.VerdictDropStream:
		finished = true
	}
	if finished {
		m.mu.Lock()
		delete(m.flows, streamID)
		m.mu.Unlock()
	}
	return verdict
}
// createFlow builds the state for a stream seen for the first time.
//
// It does the following:
//   - assigns a snowflake ID;
//   - fills in the StreamInfo used for rule matching, resolving the source MAC
//     via macResolver when the packet carried none;
//   - logs the new stream;
//   - instantiates the analyzers the current ruleset wants for it, narrowed by
//     the selector using this packet's payload.
//
// handle calls it with m.mu held, so it must not lock m.mu itself. If no
// ruleset has been installed yet, the flow gets no analyzers rather than
// panicking on a nil source.
//
// The returned flow resolves its ruleset through m on every call instead of
// holding a copy of m.rulesetSource. updateRuleset replaces that function
// value, and a copy would pin the flow to the ruleset it was created with, so
// the flow could never observe a version change.
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
	id := m.sfNode.Generate()
	ipSrc := net.IP(l3.SrcIP[:])
	ipDst := net.IP(l3.DstIP[:])
	if len(srcMAC) == 0 && m.macResolver != nil {
		srcMAC = m.macResolver.Resolve(ipSrc)
	}
	info := ruleset.StreamInfo{
		ID:       id.Int64(),
		Protocol: ruleset.ProtocolTCP,
		SrcMAC:   append(net.HardwareAddr(nil), srcMAC...),
		DstMAC:   append(net.HardwareAddr(nil), dstMAC...),
		SrcIP:    ipSrc,
		DstIP:    ipDst,
		SrcPort:  tcp.SrcPort,
		DstPort:  tcp.DstPort,
		Props:    make(analyzer.CombinedPropMap),
	}
	m.logger.TCPStreamNew(m.workerID, info)

	var (
		rs      ruleset.Ruleset
		version uint64
	)
	if m.rulesetSource != nil {
		rs, version = m.rulesetSource()
	}
	var ans []analyzer.TCPAnalyzer
	if rs != nil {
		ans = analyzersToTCPAnalyzers(m.selector.SelectTCP(rs.Analyzers(info), payload))
	}
	entries := make([]*tcpFlowEntry, 0, len(ans))
	for _, a := range ans {
		name := a.Name()
		limit := a.Limit()
		stream := a.NewTCP(analyzer.TCPInfo{
			SrcIP:   ipSrc,
			DstIP:   ipDst,
			SrcPort: tcp.SrcPort,
			DstPort: tcp.DstPort,
		}, &analyzerLogger{
			StreamID: id.Int64(),
			Name:     name,
			Logger:   m.logger,
		})
		entries = append(entries, &tcpFlowEntry{
			Name:     name,
			Stream:   stream,
			HasLimit: limit > 0,
			Quota:    limit,
		})
	}
	flow := &tcpFlow{
		streamID:       streamID,
		srcIP:          l3.SrcIP,
		dstIP:          l3.DstIP,
		srcPort:        tcp.SrcPort,
		dstPort:        tcp.DstPort,
		info:           info,
		virgin:         true,
		logger:         m.logger,
		rulesetVersion: version,
		activeEntries:  entries,
		lastVerdict:    io.VerdictAccept,
	}
	// Read the manager's source at call time; see the doc comment. Like
	// updateRuleset, this reads m.rulesetSource without taking m.mu.
	flow.rulesetSource = func() (ruleset.Ruleset, uint64) {
		src := m.rulesetSource
		if src == nil {
			return nil, flow.rulesetVersion
		}
		return src()
	}
	// Expect the client's first payload one past this packet's sequence
	// number (presumably this packet is the SYN, which consumes one).
	flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
	return flow
}
// updateRuleset installs r at the given version as the manager's ruleset
// source; afterwards rulesetSource always reports exactly (r, version).
// It does not take m.mu.
func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
	current := func() (ruleset.Ruleset, uint64) {
		return r, version
	}
	m.rulesetSource = current
}
// feedFlowEntry feeds data to one analyzer entry and reports the analyzer's
// property update, any update produced by closing it, and whether the entry
// is finished.
//
// Entries without a limit receive data unchanged. A limited entry receives at
// most its remaining Quota bytes, and Quota is reduced by what was fed. Once
// the quota is used up, the stream is closed with Close(true) and the entry is
// reported done.
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	if entry.HasLimit && len(data) > entry.Quota {
		data = data[:entry.Quota]
	}
	update, done = entry.Stream.Feed(rev, true, false, 0, data)
	if !entry.HasLimit {
		return update, nil, done
	}
	entry.Quota -= len(data)
	if entry.Quota <= 0 {
		return update, entry.Stream.Close(true), true
	}
	return update, nil, done
}
// analyzersToTCPAnalyzers keeps only the analyzers that implement
// analyzer.TCPAnalyzer, in their original order. The result is never nil.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
	out := make([]analyzer.TCPAnalyzer, 0, len(ans))
	for i := range ans {
		ta, isTCP := ans[i].(analyzer.TCPAnalyzer)
		if !isTCP {
			continue
		}
		out = append(out, ta)
	}
	return out
}
// actionToTCPVerdict maps a ruleset action to a stream-level TCP verdict:
// Block and Drop drop the stream, and every other action (including unknown
// ones) accepts it.
func actionToTCPVerdict(a ruleset.Action) io.Verdict {
	if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
		return io.VerdictDropStream
	}
	return io.VerdictAcceptStream
}