Files
Mellaris/engine/udp.go
T

441 lines
11 KiB
Go

package engine
import (
"bytes"
"errors"
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
lru "github.com/hashicorp/golang-lru/v2"
)
// udpVerdict is the UDP-side view of io.Verdict. For UDP every verdict is
// supported; each constant below is a direct conversion of the io.Verdict
// with the matching name.
type udpVerdict io.Verdict

const (
	udpVerdictAccept       = udpVerdict(io.VerdictAccept)
	udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify)
	udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream)
	udpVerdictDrop         = udpVerdict(io.VerdictDrop)
	udpVerdictDropStream   = udpVerdict(io.VerdictDropStream)
)

// errInvalidModifier is logged when a rule requests a modify action but its
// modifier instance is not a modifier.UDPModifierInstance.
var errInvalidModifier = errors.New("invalid modifier")
// udpContext carries per-packet state between the caller and the stream
// manager. The stream writes the chosen verdict into Verdict and, for a
// successful modify action, the rewritten payload into Packet. SrcMAC and
// DstMAC are the packet's MAC addresses, recorded when a new stream is created.
type udpContext struct {
	Verdict udpVerdict
	Packet  []byte
	SrcMAC  net.HardwareAddr
	DstMAC  net.HardwareAddr
}
// udpStreamFactory creates udpStream values and owns the ruleset they are
// matched against. The ruleset can be swapped with UpdateRuleset. Each swap
// bumps RulesetVersion, so streams created earlier can detect the change
// through the currentRuleset callback they hold.
type udpStreamFactory struct {
	WorkerID int
	Logger   Logger
	Node     *snowflake.Node
	Selector *analyzerSelector
	Stats    *statsCounters

	// RulesetMutex guards Ruleset and RulesetVersion.
	RulesetMutex   sync.RWMutex
	Ruleset        ruleset.Ruleset
	RulesetVersion uint64
}
// New builds the stream for tuple k.
//
// The A endpoint of k is recorded as the source and the B endpoint as the
// destination. payload is the packet payload that triggered creation; here it
// is only handed to the optional Selector to narrow the analyzer set. uc
// supplies the MAC addresses, which are copied so the stream does not alias
// the caller's slices.
//
// Each UDP analyzer returned by the current ruleset gets its own analyzer
// stream. It also gets a byte quota when its Limit() is positive.
func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
	streamID := f.Node.Generate().Int64()
	srcIP := net.IP(k.AIP[:k.ALen])
	dstIP := net.IP(k.BIP[:k.BLen])

	info := ruleset.StreamInfo{
		ID:       streamID,
		Protocol: ruleset.ProtocolUDP,
		SrcMAC:   append(net.HardwareAddr(nil), uc.SrcMAC...),
		DstMAC:   append(net.HardwareAddr(nil), uc.DstMAC...),
		SrcIP:    srcIP,
		DstIP:    dstIP,
		SrcPort:  k.APort,
		DstPort:  k.BPort,
		Props:    make(analyzer.CombinedPropMap),
	}
	f.Logger.UDPStreamNew(f.WorkerID, info)

	rs, version := f.currentRuleset()
	var udpAnalyzers []analyzer.UDPAnalyzer
	if rs != nil {
		candidates := rs.Analyzers(info)
		if f.Selector != nil {
			candidates = f.Selector.SelectUDP(candidates, payload)
		}
		udpAnalyzers = analyzersToUDPAnalyzers(candidates)
	}

	udpInfo := analyzer.UDPInfo{
		SrcIP:   srcIP,
		DstIP:   dstIP,
		SrcPort: k.APort,
		DstPort: k.BPort,
	}
	entries := make([]*udpStreamEntry, 0, len(udpAnalyzers))
	for _, ua := range udpAnalyzers {
		name := ua.Name()
		stream := ua.NewUDP(udpInfo, &analyzerLogger{
			StreamID: streamID,
			Name:     name,
			Logger:   f.Logger,
		})
		limit := ua.Limit()
		entries = append(entries, &udpStreamEntry{
			Name:     name,
			Stream:   stream,
			HasLimit: limit > 0,
			Quota:    limit,
		})
	}

	return &udpStream{
		info:           info,
		virgin:         true,
		logger:         f.Logger,
		rulesetVersion: version,
		rulesetSource:  f.currentRuleset,
		activeEntries:  entries,
	}
}
// UpdateRuleset installs r as the active ruleset and increments
// RulesetVersion under the write lock. Streams compare that version on their
// next packet and re-match when it has moved. It always returns nil.
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset, f.RulesetVersion = r, f.RulesetVersion+1
	f.RulesetMutex.Unlock()
	return nil
}
// currentRuleset returns the active ruleset and its version. Both are read
// under the shared lock, so the pair is consistent.
func (f *udpStreamFactory) currentRuleset() (rs ruleset.Ruleset, version uint64) {
	f.RulesetMutex.RLock()
	rs, version = f.Ruleset, f.RulesetVersion
	f.RulesetMutex.RUnlock()
	return rs, version
}
// udpStreamManager tracks live UDP streams.
//
// Streams live in an LRU keyed by stream ID. tupleIndex (tuple -> ID) and
// streamTuples (ID -> tuple) form a two-way mapping, so a known flow can be
// found again when it shows up under a different stream ID.
type udpStreamManager struct {
	factory      *udpStreamFactory
	streams      *lru.Cache[uint32, *udpStreamValue]
	tupleIndex   map[udpTupleKey]uint32
	streamTuples map[uint32]udpTupleKey
	stats        *statsCounters
}
// udpStreamValue is the LRU payload for one stream ID: the stream itself and
// the tuple it was created for.
type udpStreamValue struct {
	Stream *udpStream
	Tuple  udpTupleKey
}

// Match reports whether k identifies this stream.
//
// It returns (true, false) when k equals the stored tuple, (true, true) when k
// is the stored tuple with its endpoints swapped, and (false, false)
// otherwise. The exact match is checked first. A tuple whose two endpoints are
// identical (and therefore equal to its own reverse) is reported as forward
// rather than reverse.
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
	if v.Tuple == k {
		return true, false
	}
	if v.Tuple == reverseTuple(k) {
		return true, true
	}
	return false, false
}
// udpTupleKey identifies a UDP flow by its two endpoints. It contains only
// arrays and integers, so it is comparable with == and usable as a map key.
// Only the first ALen bytes of AIP and BLen bytes of BIP hold address data;
// the remainder is zero padding.
type udpTupleKey struct {
	// Endpoint addresses, zero-padded to 16 bytes.
	AIP [16]byte
	BIP [16]byte
	// Number of meaningful bytes in AIP / BIP.
	ALen uint8
	BLen uint8
	// Endpoint ports.
	APort uint16
	BPort uint16
}
// newUDPStreamManager returns a manager holding at most maxStreams streams.
//
// Whenever the LRU drops an entry, the callback below closes that entry's
// stream (if any) and clears the stream ID's tuple mapping. An entry is
// dropped on capacity eviction and also on an explicit Remove. Any error
// from lru.NewWithEvict is returned unchanged.
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
	mgr := &udpStreamManager{
		factory:      factory,
		tupleIndex:   make(map[udpTupleKey]uint32, maxStreams),
		streamTuples: make(map[uint32]udpTupleKey, maxStreams),
		stats:        stats,
	}

	onEvict := func(id uint32, val *udpStreamValue) {
		if val != nil && val.Stream != nil {
			val.Stream.Close()
		}
		mgr.removeTupleMappingLocked(id)
	}

	cache, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, onEvict)
	if err != nil {
		return nil, err
	}
	mgr.streams = cache
	return mgr, nil
}
// MatchWithContext routes one UDP packet to its stream. If the stream still
// wants packets, it also feeds payload to it. The stream writes its verdict
// (and any rewritten packet) into uc.
//
// Lookup order:
//  1. By streamID in the LRU. If the cached stream matches tuple (in either
//     direction), it is used and Match supplies the direction. If not, the
//     ID has been reused for another flow. The old stream is closed and
//     replaced by a new one.
//  2. Otherwise by tuple through tupleIndex. A hit re-keys the existing
//     stream under streamID without closing it.
//  3. Otherwise a new stream is created from tuple, payload and uc.
//
// rev is the caller's direction hint. It is used unchanged for every newly
// created stream (paths 1-replace and 3). For an existing stream, the
// direction reported by udpStreamValue.Match replaces it.
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
	value, cached := m.streams.Get(streamID)
	if cached {
		if matched, matchedRev := value.Match(tuple); matched {
			rev = matchedRev
		} else {
			// Same ID, different flow: retire the old stream and start fresh.
			// Add on an existing key replaces the value without firing the
			// eviction callback, so the old stream is closed exactly once here.
			// rev keeps the caller's hint, as for any other new stream.
			value.Stream.Close()
			value = &udpStreamValue{
				Stream: m.factory.New(tuple, payload, uc),
				Tuple:  tuple,
			}
			m.streams.Add(streamID, value)
			m.bindTupleLocked(streamID, tuple)
		}
	} else {
		if m.stats != nil {
			m.stats.UDPTupleLookups.Add(1)
		}
		var matchedValue *udpStreamValue
		matchedKey, found := m.tupleIndex[tuple]
		if found {
			if m.stats != nil {
				m.stats.UDPTupleHits.Add(1)
			}
			var hasValue bool
			matchedValue, hasValue = m.streams.Get(matchedKey)
			if !hasValue || matchedValue == nil {
				// Stale index entry: the stream it pointed at is gone.
				delete(m.tupleIndex, tuple)
				delete(m.streamTuples, matchedKey)
				found = false
			}
		}

		if found {
			value = matchedValue
			_, rev = matchedValue.Match(tuple)
			if matchedKey != streamID {
				// Re-key the live stream under the new ID. lru.Cache.Remove
				// invokes the eviction callback, which would Close() the stream
				// we are about to keep using. Overwrite the old entry with nil
				// first (Add on an existing key does not evict), so the callback
				// only clears the old ID's tuple mapping.
				m.streams.Add(matchedKey, nil)
				m.streams.Remove(matchedKey)
				m.streams.Add(streamID, matchedValue)
				m.bindTupleLocked(streamID, tuple)
			}
		} else {
			value = &udpStreamValue{
				Stream: m.factory.New(tuple, payload, uc),
				Tuple:  tuple,
			}
			m.streams.Add(streamID, value)
			m.bindTupleLocked(streamID, tuple)
		}
	}

	if value.Stream.Accept(rev, uc) {
		value.Stream.Feed(rev, payload, uc)
	}
}
// bindTupleLocked makes key the tuple owned by streamID. It first drops
// whatever tuple streamID owned before.
// NOTE(review): the "Locked" suffix implies callers serialize access, but the
// manager holds no mutex of its own. Presumably each manager is confined to a
// single goroutine — confirm.
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
	m.removeTupleMappingLocked(streamID)
	m.streamTuples[streamID] = key
	m.tupleIndex[key] = streamID
}
// removeTupleMappingLocked forgets the tuple owned by streamID, if any. The
// tuple -> ID entry is deleted only while it still points at streamID,
// because the tuple may since have been re-bound to a different stream ID.
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
	key, owned := m.streamTuples[streamID]
	if !owned {
		return
	}
	delete(m.streamTuples, streamID)
	if owner, present := m.tupleIndex[key]; present && owner == streamID {
		delete(m.tupleIndex, key)
	}
}
// canonicalUDPTupleKey builds a direction-independent key for a UDP flow.
// The endpoint that sorts lower under compareIPEndpoint becomes A and the
// other becomes B, so a packet and its reply (src/dst swapped) map to the same
// key. Address bytes beyond 16 are dropped by copy.
func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
	aIP, aPort := []byte(srcIP), srcPort
	bIP, bPort := []byte(dstIP), dstPort
	if compareIPEndpoint(aIP, aPort, bIP, bPort) > 0 {
		aIP, aPort, bIP, bPort = bIP, bPort, aIP, aPort
	}

	key := udpTupleKey{APort: aPort, BPort: bPort}
	key.ALen = uint8(copy(key.AIP[:], aIP))
	key.BLen = uint8(copy(key.BIP[:], bIP))
	return key
}
// reverseTuple returns k with its A and B endpoints (address, length and
// port) swapped.
func reverseTuple(k udpTupleKey) udpTupleKey {
	return udpTupleKey{
		AIP:   k.BIP,
		BIP:   k.AIP,
		ALen:  k.BLen,
		BLen:  k.ALen,
		APort: k.BPort,
		BPort: k.APort,
	}
}
// compareIPEndpoint orders two (address, port) endpoints and returns -1, 0 or
// +1. The comparison works in three steps:
//  1. Shorter addresses sort first, so 4-byte addresses precede 16-byte ones.
//  2. Addresses of equal length are then compared bytewise.
//  3. Finally the ports are compared.
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
	switch {
	case len(aIP) < len(bIP):
		return -1
	case len(aIP) > len(bIP):
		return 1
	}

	if byAddr := bytes.Compare(aIP, bIP); byAddr != 0 {
		return byAddr
	}

	switch {
	case aPort < bPort:
		return -1
	case aPort > bPort:
		return 1
	default:
		return 0
	}
}
// udpStream is the per-flow state. It holds the stream info and properties
// that are matched against the ruleset, the analyzers still consuming packets,
// and the verdict last issued.
type udpStream struct {
	info ruleset.StreamInfo
	// virgin stays true until Feed has run the first ruleset match.
	virgin bool
	logger Logger

	// rulesetVersion is the version last seen by Feed. rulesetSource yields
	// the current ruleset and version so a change can be detected.
	rulesetVersion uint64
	rulesetSource  func() (ruleset.Ruleset, uint64)

	// activeEntries are still being fed. doneEntries finished or were closed.
	activeEntries []*udpStreamEntry
	doneEntries   []*udpStreamEntry

	// lastVerdict is copied into the packet context when Accept declines.
	lastVerdict udpVerdict
}
// udpStreamEntry pairs one analyzer's stream with its bookkeeping. Name is the
// analyzer name, passed to processPropUpdate together with the stream's
// property updates.
type udpStreamEntry struct {
	Name   string
	Stream analyzer.UDPStream
	// HasLimit is set when the analyzer reported a positive Limit(). Quota
	// then counts the payload bytes it may still be fed.
	HasLimit bool
	Quota    int
}
// Accept reports whether the packet should go through Feed. It returns true
// in three cases:
//   - analyzers are still active;
//   - this is the stream's first packet, so every stream is matched at least once;
//   - the ruleset version has moved since Feed last looked.
//
// Otherwise it copies the stream's last verdict into uc and returns false. rev
// is currently unused.
func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
	needsFeed := len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged()
	if !needsFeed {
		uc.Verdict = s.lastVerdict
	}
	return needsFeed
}
// Feed runs one packet payload through the stream's active analyzers. Entries
// that report done, or exhaust their byte quota, move to doneEntries.
//
// The stream is then re-matched against the current ruleset when at least one
// of these holds:
//   - an analyzer produced a property update;
//   - this is the first packet;
//   - the ruleset version changed.
//
// A non-maybe action sets uc.Verdict and lastVerdict. A final verdict closes
// all remaining analyzers. For a successful modify action, uc.Packet receives
// the modifier's output. If the modifier is not a UDP modifier or fails,
// uc.Packet is left untouched and the action falls back to maybe.
//
// Finally, once no analyzers remain and uc.Verdict is still udpVerdictAccept,
// the whole stream is accepted.
func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
	updated := false
	// Iterate backwards so removing entry i leaves unvisited indices intact.
	for i := len(s.activeEntries) - 1; i >= 0; i-- {
		entry := s.activeEntries[i]
		update, closeUpdate, done := s.feedEntry(entry, rev, payload)
		up1 := processPropUpdate(s.info.Props, entry.Name, update)
		up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
		updated = updated || up1 || up2
		if done {
			s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
			s.doneEntries = append(s.doneEntries, entry)
		}
	}

	rs, version := s.currentRuleset()
	rulesetChanged := version != s.rulesetVersion
	s.rulesetVersion = version
	if updated || s.virgin || rulesetChanged {
		s.virgin = false
		s.logger.UDPStreamPropUpdate(s.info, false)
		// Match properties against the ruleset; no ruleset means "maybe".
		result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
		if rs != nil {
			result = rs.Match(s.info)
		}
		action := result.Action
		if action == ruleset.ActionModify {
			udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance)
			if !ok {
				// Not a UDP modifier: fall back to maybe.
				s.logger.ModifyError(s.info, errInvalidModifier)
				action = ruleset.ActionMaybe
			} else if modified, err := udpMI.Process(payload); err != nil {
				// Modifier failed: fall back to maybe. uc.Packet is deliberately
				// left alone so a nil or partial output is never handed back
				// alongside a non-modify verdict.
				s.logger.ModifyError(s.info, err)
				action = ruleset.ActionMaybe
			} else {
				uc.Packet = modified
			}
		}
		if action != ruleset.ActionMaybe {
			verdict, final := actionToUDPVerdict(action)
			// NOTE(review): if this is udpVerdictAcceptModify and the analyzers
			// later all finish, Accept replays it for subsequent packets without
			// setting uc.Packet. Confirm how the io layer treats that case.
			s.lastVerdict = verdict
			uc.Verdict = verdict
			s.logger.UDPStreamAction(s.info, action, false)
			if final {
				s.closeActiveEntries()
			}
		}
	}

	if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept {
		// All entries are done but no verdict issued, accept stream.
		s.lastVerdict = udpVerdictAcceptStream
		uc.Verdict = udpVerdictAcceptStream
		s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// currentRuleset returns the ruleset and version from rulesetSource, which
// udpStreamFactory.New sets to the factory's accessor. Without a source, it
// reports a nil ruleset and the stream's own recorded version. Such a stream
// therefore never observes a ruleset change.
func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) {
	if src := s.rulesetSource; src != nil {
		return src()
	}
	return nil, s.rulesetVersion
}
// rulesetChanged reports whether the ruleset version differs from the one
// Feed last recorded for this stream.
func (s *udpStream) rulesetChanged() bool {
	_, latest := s.currentRuleset()
	return latest != s.rulesetVersion
}
// Close signals end-of-stream to every still-active analyzer. Calling it
// again is harmless: later calls find no active entries.
func (s *udpStream) Close() {
	s.closeActiveEntries()
}

// closeActiveEntries calls Close(false) on each active analyzer stream and
// merges the returned property updates. It logs a final prop update if any of
// them changed something, then moves all active entries to doneEntries.
func (s *udpStream) closeActiveEntries() {
	changed := false
	for i := range s.activeEntries {
		e := s.activeEntries[i]
		if processPropUpdate(s.info.Props, e.Name, e.Stream.Close(false)) {
			changed = true
		}
	}
	if changed {
		s.logger.UDPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry hands data to one analyzer stream. When the entry has a limit, the
// payload length is also charged against its quota. If the quota drops to
// zero or below, the analyzer is closed with Close(true) and the entry is
// reported done whatever the analyzer itself said. The caller then moves it to
// doneEntries.
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	update, done = entry.Stream.Feed(rev, data)
	if !entry.HasLimit {
		return update, nil, done
	}

	entry.Quota -= len(data)
	if entry.Quota > 0 {
		return update, nil, done
	}

	// Quota exhausted: signal close and mark done.
	return update, entry.Stream.Close(true), true
}
// analyzersToUDPAnalyzers keeps only the analyzers that implement
// analyzer.UDPAnalyzer, preserving their order.
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
	out := make([]analyzer.UDPAnalyzer, 0, len(ans))
	for i := range ans {
		ua, isUDP := ans[i].(analyzer.UDPAnalyzer)
		if !isUDP {
			continue
		}
		out = append(out, ua)
	}
	return out
}
// actionToUDPVerdict maps a ruleset action to a UDP verdict. final is true for
// Allow and Block, whose stream-level verdicts end analysis. Maybe and any
// unrecognized action map to udpVerdictAccept, not final.
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
	v = udpVerdictAccept
	switch a {
	case ruleset.ActionAllow:
		v, final = udpVerdictAcceptStream, true
	case ruleset.ActionBlock:
		v, final = udpVerdictDropStream, true
	case ruleset.ActionDrop:
		v = udpVerdictDrop
	case ruleset.ActionModify:
		v = udpVerdictAcceptModify
	}
	return v, final
}