Files
Mellaris/engine/udp.go
T
hayzam 7a3f6e945d Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
2026-05-13 06:10:38 +05:30

445 lines
12 KiB
Go

package engine
import (
"bytes"
"errors"
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2"
)
// udpVerdict is the verdict type used for UDP streams, defined on io.Verdict.
// It is not a restricted subset: for UDP, we support all verdicts.
type udpVerdict io.Verdict
// Each constant is the io.Verdict of the same name converted to udpVerdict.
// The ruleset action that yields each one is decided in actionToUDPVerdict.
const (
// Accept this packet; Feed also reads it as "no verdict issued yet".
udpVerdictAccept = udpVerdict(io.VerdictAccept)
// Accept this packet with the payload produced by a modifier.
udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify)
// Accept the stream; issued for ActionAllow or when all analyzers are done.
udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream)
// Drop verdict issued for ActionDrop (non-final).
udpVerdictDrop = udpVerdict(io.VerdictDrop)
// Drop verdict issued for ActionBlock (final).
udpVerdictDropStream = udpVerdict(io.VerdictDropStream)
)
// errInvalidModifier is logged when a ruleset match asks for a modify action
// but its modifier instance does not implement modifier.UDPModifierInstance.
var errInvalidModifier = errors.New("invalid modifier")
// udpContext carries per-packet data into, and results out of, the UDP
// stream manager.
type udpContext struct {
// Verdict is written by udpStream.Accept and udpStream.Feed.
Verdict udpVerdict
// Packet receives the modifier's output when a modify action is processed.
Packet []byte
// SrcMAC and DstMAC are copied into the StreamInfo when a stream is created.
SrcMAC, DstMAC net.HardwareAddr
}
// udpStreamFactory creates udpStream values and owns the ruleset they are
// matched against.
type udpStreamFactory struct {
// WorkerID is reported to the logger when a new stream is created.
WorkerID int
Logger Logger
// Node generates the unique ID assigned to each new stream.
Node *snowflake.Node
// Selector, when non-nil, narrows the analyzers chosen by the ruleset using
// the first packet's payload.
Selector *analyzerSelector
// NOTE(review): Stats is not read by any factory method in this file.
Stats *statsCounters
// RulesetMutex guards Ruleset and RulesetVersion.
RulesetMutex sync.RWMutex
Ruleset ruleset.Ruleset
// RulesetVersion is incremented on every UpdateRuleset call.
RulesetVersion uint64
}
// New builds the udpStream for a flow whose first packet is udp.
//
// It assigns a fresh stream ID, records the endpoints (and a copy of the MAC
// addresses from uc) in the StreamInfo, announces the stream to the logger,
// and instantiates one analyzer stream per UDP-capable analyzer selected by
// the current ruleset. With no ruleset installed the stream starts with no
// analyzers. The udpFlow parameter is accepted but not used.
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
	streamID := f.Node.Generate().Int64()
	srcIP := net.IP(ipFlow.Src().Raw())
	dstIP := net.IP(ipFlow.Dst().Raw())
	srcPort, dstPort := uint16(udp.SrcPort), uint16(udp.DstPort)

	info := ruleset.StreamInfo{
		ID:       streamID,
		Protocol: ruleset.ProtocolUDP,
		// Copy the MACs so the stream does not alias the caller's buffers.
		SrcMAC:  append(net.HardwareAddr(nil), uc.SrcMAC...),
		DstMAC:  append(net.HardwareAddr(nil), uc.DstMAC...),
		SrcIP:   srcIP,
		DstIP:   dstIP,
		SrcPort: srcPort,
		DstPort: dstPort,
		Props:   make(analyzer.CombinedPropMap),
	}
	f.Logger.UDPStreamNew(f.WorkerID, info)

	rs, version := f.currentRuleset()
	var udpAnalyzers []analyzer.UDPAnalyzer
	if rs != nil {
		candidates := rs.Analyzers(info)
		if f.Selector != nil {
			candidates = f.Selector.SelectUDP(candidates, udp.Payload)
		}
		udpAnalyzers = analyzersToUDPAnalyzers(candidates)
	}

	// Every analyzer receives the same endpoint description.
	udpInfo := analyzer.UDPInfo{
		SrcIP:   srcIP,
		DstIP:   dstIP,
		SrcPort: srcPort,
		DstPort: dstPort,
	}
	entries := make([]*udpStreamEntry, 0, len(udpAnalyzers))
	for _, a := range udpAnalyzers {
		entry := &udpStreamEntry{Name: a.Name()}
		entry.Stream = a.NewUDP(udpInfo, &analyzerLogger{
			StreamID: streamID,
			Name:     a.Name(),
			Logger:   f.Logger,
		})
		entry.HasLimit = a.Limit() > 0
		entry.Quota = a.Limit()
		entries = append(entries, entry)
	}

	return &udpStream{
		info:           info,
		virgin:         true,
		logger:         f.Logger,
		rulesetVersion: version,
		rulesetSource:  f.currentRuleset,
		activeEntries:  entries,
	}
}
// UpdateRuleset installs r as the active ruleset and bumps the ruleset
// version so existing streams can detect the change. It always returns nil.
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset, f.RulesetVersion = r, f.RulesetVersion+1
	f.RulesetMutex.Unlock()
	return nil
}
// currentRuleset returns the active ruleset together with its version, read
// as one consistent pair under the read lock.
func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
	f.RulesetMutex.RLock()
	rs, version := f.Ruleset, f.RulesetVersion
	f.RulesetMutex.RUnlock()
	return rs, version
}
// udpStreamManager tracks UDP streams by stream ID in an LRU cache and keeps
// a canonical 5-tuple index so a flow can be found again under a new ID.
// NOTE(review): the maps are accessed without locking in this file; the
// "Locked" helper suffix suggests callers serialize access - confirm.
type udpStreamManager struct {
// factory builds the udpStream for every newly seen flow.
factory *udpStreamFactory
// streams maps stream ID to the tracked flow.
streams *lru.Cache[uint32, *udpStreamValue]
// tupleIndex maps a canonical tuple to the stream ID currently bound to it.
tupleIndex map[udpTupleKey]uint32
// streamTuples maps a stream ID to the tuple it is bound to.
streamTuples map[uint32]udpTupleKey
// stats is optional; when nil the tuple lookup/hit counters are skipped.
stats *statsCounters
}
// udpStreamValue is the cache entry for one tracked UDP flow.
type udpStreamValue struct {
	Stream *udpStream
	// IPFlow and UDPFlow are the network and transport flows of the packet
	// that created the stream; they define the "forward" direction.
	IPFlow  gopacket.Flow
	UDPFlow gopacket.Flow

	// rebinding is true only while MatchWithContext moves this value to a
	// different stream ID. The eviction callback checks it so that removing
	// the old key does not close a stream that stays alive under the new key.
	rebinding bool
}

// udpTupleKey is a direction-independent key for a UDP flow. The two
// endpoints are stored in the order chosen by canonicalUDPTupleKey, so both
// directions of a flow produce the same key.
type udpTupleKey struct {
	AIP [16]byte
	BIP [16]byte
	// ALen and BLen are the number of address bytes copied into AIP and BIP.
	ALen  uint8
	BLen  uint8
	APort uint16
	BPort uint16
}

// Match reports whether the given flows belong to this stream. ok is true
// when they match in either direction; rev is true when they match the
// reverse of the direction recorded in v.
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
	fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
	rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
	return fwd || rev, rev
}

// newUDPStreamManager creates a manager that tracks at most maxStreams
// streams. stats may be nil. It returns the error from the LRU constructor
// unchanged.
//
// When a stream leaves the cache (LRU eviction or explicit removal) its
// tuple mapping is dropped and the stream is closed so its analyzers receive
// their close signal. The close is skipped for a value that is only being
// moved to a new stream ID.
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
	m := &udpStreamManager{
		factory:      factory,
		tupleIndex:   make(map[udpTupleKey]uint32, maxStreams),
		streamTuples: make(map[uint32]udpTupleKey, maxStreams),
		stats:        stats,
	}
	ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
		m.removeTupleMappingLocked(k)
		if v == nil || v.rebinding || v.Stream == nil {
			return
		}
		// The stream is unreachable from here on; without this its analyzers
		// would never be closed.
		v.Stream.Close()
	})
	if err != nil {
		return nil, err
	}
	m.streams = ss
	return m, nil
}

// MatchWithContext routes one UDP packet to its stream, creating or
// rebinding the stream as needed, and then lets the stream process it.
//
// Lookup order:
//  1. By streamID. If the ID is known but the flows do not match, the old
//     stream is closed and replaced by a new one under the same ID.
//  2. Otherwise by canonical 5-tuple. A hit means the flow is already
//     tracked under another ID, so the existing stream is moved to streamID.
//  3. Otherwise a new stream is created.
//
// The verdict (and modified packet, if any) is reported through uc.
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
	udpFlow := udp.TransportFlow()
	tuple := canonicalUDPTupleKey(ipFlow, udp)

	rev := false
	value, ok := m.streams.Get(streamID)
	if !ok {
		if m.stats != nil {
			m.stats.UDPTupleLookups.Add(1)
		}
		// Conntrack IDs can change during early flow lifetime on some systems.
		// Rebind by canonical 5-tuple in O(1).
		matchedKey, found := m.tupleIndex[tuple]
		var matchedValue *udpStreamValue
		if found {
			if m.stats != nil {
				m.stats.UDPTupleHits.Add(1)
			}
			var hasValue bool
			matchedValue, hasValue = m.streams.Get(matchedKey)
			if !hasValue || matchedValue == nil {
				// The index points at a stream that is no longer cached;
				// drop the stale mapping and treat the flow as new.
				delete(m.tupleIndex, tuple)
				delete(m.streamTuples, matchedKey)
				found = false
			}
		}
		if found {
			_, rev = matchedValue.Match(ipFlow, udpFlow)
			value = matchedValue
			if matchedKey != streamID {
				// Move the live stream to the new ID. Remove fires the
				// eviction callback, so mark the value to keep it open.
				matchedValue.rebinding = true
				m.streams.Remove(matchedKey)
				matchedValue.rebinding = false
				m.streams.Add(streamID, matchedValue)
				m.bindTupleLocked(streamID, tuple)
			}
		} else {
			// New stream
			value = &udpStreamValue{
				Stream:  m.factory.New(ipFlow, udpFlow, udp, uc),
				IPFlow:  ipFlow,
				UDPFlow: udpFlow,
			}
			m.streams.Add(streamID, value)
			m.bindTupleLocked(streamID, tuple)
		}
	} else {
		// Stream ID exists, but is it really the same stream?
		ok, rev = value.Match(ipFlow, udpFlow)
		if !ok {
			// It's not - close the old stream & replace it with a new one.
			// Adding under an existing key replaces the value in place, so
			// the old stream has to be closed explicitly here.
			value.Stream.Close()
			value = &udpStreamValue{
				Stream:  m.factory.New(ipFlow, udpFlow, udp, uc),
				IPFlow:  ipFlow,
				UDPFlow: udpFlow,
			}
			m.streams.Add(streamID, value)
			m.bindTupleLocked(streamID, tuple)
		}
	}
	if value.Stream.Accept(udp, rev, uc) {
		value.Stream.Feed(udp, rev, uc)
	}
}
// bindTupleLocked makes key the tuple bound to streamID, first discarding
// whatever tuple that stream ID was bound to before. Any other stream ID
// previously indexed under key is overwritten in tupleIndex.
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
	// Drop the previous binding so tupleIndex keeps no entry for the old tuple.
	m.removeTupleMappingLocked(streamID)

	m.streamTuples[streamID] = key
	m.tupleIndex[key] = streamID
}
// removeTupleMappingLocked forgets the tuple bound to streamID. The tuple's
// tupleIndex entry is deleted only if it still points at streamID, so a
// tuple that has since been bound to another stream is left alone.
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
	key, tracked := m.streamTuples[streamID]
	if !tracked {
		return
	}
	delete(m.streamTuples, streamID)
	if owner, indexed := m.tupleIndex[key]; indexed && owner == streamID {
		delete(m.tupleIndex, key)
	}
}
// canonicalUDPTupleKey builds the direction-independent key for a packet.
// The endpoint that orders first under compareIPEndpoint becomes the "A"
// side, so a packet and its reply map to the same key.
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
	aIP, aPort := ipFlow.Src().Raw(), uint16(udp.SrcPort)
	bIP, bPort := ipFlow.Dst().Raw(), uint16(udp.DstPort)
	if compareIPEndpoint(aIP, aPort, bIP, bPort) > 0 {
		// Source orders after destination: swap so A is the smaller endpoint.
		aIP, bIP = bIP, aIP
		aPort, bPort = bPort, aPort
	}

	key := udpTupleKey{APort: aPort, BPort: bPort}
	// copy returns the number of address bytes written, which is the length.
	key.ALen = uint8(copy(key.AIP[:], aIP))
	key.BLen = uint8(copy(key.BIP[:], bIP))
	return key
}
// compareIPEndpoint orders two (IP, port) endpoints and returns -1, 0 or 1.
// Endpoints are ordered by address length first, then by address bytes,
// then by port.
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
	// A shorter address always sorts before a longer one.
	switch {
	case len(aIP) < len(bIP):
		return -1
	case len(aIP) > len(bIP):
		return 1
	}
	if order := bytes.Compare(aIP, bIP); order != 0 {
		return order
	}
	// Same address: the port breaks the tie.
	switch {
	case aPort < bPort:
		return -1
	case aPort > bPort:
		return 1
	default:
		return 0
	}
}
// udpStream holds the analysis state of a single UDP flow: its analyzers,
// the properties they have produced, and the last verdict issued for it.
type udpStream struct {
// info is the stream description (including analyzer properties) that is
// matched against the ruleset.
info ruleset.StreamInfo
virgin bool // true if no packets have been processed
logger Logger
// rulesetVersion is the ruleset version this stream last observed in Feed.
rulesetVersion uint64
// rulesetSource returns the current ruleset and its version; may be nil.
rulesetSource func() (ruleset.Ruleset, uint64)
// activeEntries are analyzers that still receive packets.
activeEntries []*udpStreamEntry
// doneEntries are analyzers that finished, hit their quota, or were closed.
doneEntries []*udpStreamEntry
// lastVerdict is replayed by Accept when a packet does not need Feed.
lastVerdict udpVerdict
}
// udpStreamEntry pairs one analyzer's UDP stream with its byte quota.
type udpStreamEntry struct {
// Name is the analyzer name, passed along with each property update.
Name string
Stream analyzer.UDPStream
// HasLimit is true when the analyzer reported a limit greater than zero.
HasLimit bool
// Quota is the remaining payload byte budget; it is only decremented and
// enforced when HasLimit is true.
Quota int
}
// Accept reports whether the packet must be passed to Feed. When it returns
// false, the stream's last verdict is copied into uc instead.
//
// A packet needs Feed while analyzers are still active, on the very first
// packet, or after the ruleset has changed. The first-packet rule ensures
// every stream is matched against the ruleset at least once even without
// analyzers, because the ruleset may have built-in properties to match.
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
	needsFeed := len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged()
	if !needsFeed {
		uc.Verdict = s.lastVerdict
		return false
	}
	return true
}
// Feed runs one packet through the stream's active analyzers and, when
// anything relevant changed, re-matches the stream against the ruleset.
//
// The ruleset is consulted if an analyzer updated a property, if this is the
// first packet, or if the ruleset version changed since the last call. A
// resulting action other than "maybe" is converted to a verdict, stored as
// the stream's last verdict and written to uc; a final verdict also closes
// all remaining analyzers. A modify action writes the modifier's output to
// uc.Packet; if the modifier is not a UDP modifier or fails, the error is
// logged and the action is downgraded to "maybe".
//
// When no analyzers remain and uc still holds the plain accept verdict, the
// whole stream is accepted.
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
	propsChanged := false
	// Walk backwards so finished entries can be removed while iterating.
	for idx := len(s.activeEntries) - 1; idx >= 0; idx-- {
		e := s.activeEntries[idx]
		feedUpdate, closeUpdate, finished := s.feedEntry(e, rev, udp.Payload)
		if processPropUpdate(s.info.Props, e.Name, feedUpdate) {
			propsChanged = true
		}
		if processPropUpdate(s.info.Props, e.Name, closeUpdate) {
			propsChanged = true
		}
		if finished {
			s.activeEntries = append(s.activeEntries[:idx], s.activeEntries[idx+1:]...)
			s.doneEntries = append(s.doneEntries, e)
		}
	}

	rs, version := s.currentRuleset()
	rematch := propsChanged || s.virgin || version != s.rulesetVersion
	s.rulesetVersion = version

	if rematch {
		s.virgin = false
		s.logger.UDPStreamPropUpdate(s.info, false)

		// With no ruleset installed the outcome stays "maybe".
		match := ruleset.MatchResult{Action: ruleset.ActionMaybe}
		if rs != nil {
			match = rs.Match(s.info)
		}
		action := match.Action

		if action == ruleset.ActionModify {
			if udpMI, isUDP := match.ModInstance.(modifier.UDPModifierInstance); isUDP {
				var err error
				if uc.Packet, err = udpMI.Process(udp.Payload); err != nil {
					// Modifier failed: log it and treat the match as "maybe".
					s.logger.ModifyError(s.info, err)
					action = ruleset.ActionMaybe
				}
			} else {
				// The modifier cannot handle UDP: treat the match as "maybe".
				s.logger.ModifyError(s.info, errInvalidModifier)
				action = ruleset.ActionMaybe
			}
		}

		if action != ruleset.ActionMaybe {
			verdict, final := actionToUDPVerdict(action)
			s.lastVerdict, uc.Verdict = verdict, verdict
			s.logger.UDPStreamAction(s.info, action, false)
			if final {
				s.closeActiveEntries()
			}
		}
	}

	if uc.Verdict == udpVerdictAccept && len(s.activeEntries) == 0 {
		// Nothing left to analyze and no verdict was issued: accept the stream.
		s.lastVerdict = udpVerdictAcceptStream
		uc.Verdict = udpVerdictAcceptStream
		s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// currentRuleset returns the ruleset and version reported by the stream's
// ruleset source. Without a source it returns a nil ruleset and the version
// the stream already holds, so no change is ever detected.
func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) {
	if source := s.rulesetSource; source != nil {
		return source()
	}
	return nil, s.rulesetVersion
}
// rulesetChanged reports whether the current ruleset version differs from
// the one this stream last recorded.
func (s *udpStream) rulesetChanged() bool {
	_, latest := s.currentRuleset()
	if latest == s.rulesetVersion {
		return false
	}
	return true
}
// Close shuts down every analyzer that is still active on the stream.
func (s *udpStream) Close() {
	s.closeActiveEntries()
}

// closeActiveEntries sends the (non-limit) close signal to each active
// analyzer, applies any property updates they return, and moves all of them
// to doneEntries. A property update is logged once if anything changed.
func (s *udpStream) closeActiveEntries() {
	changed := false
	for i := range s.activeEntries {
		e := s.activeEntries[i]
		if processPropUpdate(s.info.Props, e.Name, e.Stream.Close(false)) {
			changed = true
		}
	}
	if changed {
		s.logger.UDPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry passes one payload to a single analyzer and charges its size
// against the analyzer's quota, if it has one.
//
// update and done are the analyzer's own results. When the quota drops to
// zero or below, the analyzer is closed with the limit flag set, its close
// update is returned as closeUpdate, and done is forced to true.
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	update, done = entry.Stream.Feed(rev, data)
	if !entry.HasLimit {
		return update, nil, done
	}
	entry.Quota -= len(data)
	if entry.Quota > 0 {
		return update, nil, done
	}
	// Quota exhausted: signal close and report the entry as finished.
	return update, entry.Stream.Close(true), true
}
// analyzersToUDPAnalyzers returns, in their original order, the analyzers
// that implement analyzer.UDPAnalyzer. The result is never nil.
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
	selected := make([]analyzer.UDPAnalyzer, 0, len(ans))
	for i := range ans {
		udpAn, isUDP := ans[i].(analyzer.UDPAnalyzer)
		if !isUDP {
			continue
		}
		selected = append(selected, udpAn)
	}
	return selected
}
// actionToUDPVerdict converts a ruleset action into the UDP verdict to
// issue. final is true when the verdict settles the whole stream (allow and
// block); drop and modify are non-final. ActionMaybe and any unrecognized
// action yield the plain, non-final accept verdict.
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
	switch a {
	case ruleset.ActionAllow:
		return udpVerdictAcceptStream, true
	case ruleset.ActionBlock:
		return udpVerdictDropStream, true
	case ruleset.ActionDrop:
		return udpVerdictDrop, false
	case ruleset.ActionModify:
		return udpVerdictAcceptModify, false
	}
	// ActionMaybe, plus anything unexpected (should never happen).
	return udpVerdictAccept, false
}