441 lines
11 KiB
Go
441 lines
11 KiB
Go
package engine
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
|
"git.difuse.io/Difuse/Mellaris/io"
|
|
"git.difuse.io/Difuse/Mellaris/modifier"
|
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
|
|
|
"github.com/bwmarrin/snowflake"
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
)
|
|
|
|
// udpVerdict is a subset of io.Verdict for UDP streams.
|
|
// For UDP, we support all verdicts.
|
|
type udpVerdict io.Verdict
|
|
|
|
const (
|
|
udpVerdictAccept = udpVerdict(io.VerdictAccept)
|
|
udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify)
|
|
udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream)
|
|
udpVerdictDrop = udpVerdict(io.VerdictDrop)
|
|
udpVerdictDropStream = udpVerdict(io.VerdictDropStream)
|
|
)
|
|
|
|
// errInvalidModifier is logged when a matched rule's modifier instance does not
// implement modifier.UDPModifierInstance.
var errInvalidModifier = errors.New("invalid modifier")
|
|
|
|
type udpContext struct {
|
|
Verdict udpVerdict
|
|
Packet []byte
|
|
SrcMAC, DstMAC net.HardwareAddr
|
|
}
|
|
|
|
type udpStreamFactory struct {
|
|
WorkerID int
|
|
Logger Logger
|
|
Node *snowflake.Node
|
|
Selector *analyzerSelector
|
|
Stats *statsCounters
|
|
|
|
RulesetMutex sync.RWMutex
|
|
Ruleset ruleset.Ruleset
|
|
RulesetVersion uint64
|
|
}
|
|
|
|
func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
|
|
id := f.Node.Generate()
|
|
ipSrc := net.IP(k.AIP[:k.ALen])
|
|
ipDst := net.IP(k.BIP[:k.BLen])
|
|
info := ruleset.StreamInfo{
|
|
ID: id.Int64(),
|
|
Protocol: ruleset.ProtocolUDP,
|
|
SrcMAC: append(net.HardwareAddr(nil), uc.SrcMAC...),
|
|
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
|
|
SrcIP: ipSrc,
|
|
DstIP: ipDst,
|
|
SrcPort: k.APort,
|
|
DstPort: k.BPort,
|
|
Props: make(analyzer.CombinedPropMap),
|
|
}
|
|
f.Logger.UDPStreamNew(f.WorkerID, info)
|
|
rs, version := f.currentRuleset()
|
|
var ans []analyzer.UDPAnalyzer
|
|
if rs != nil {
|
|
baseAns := rs.Analyzers(info)
|
|
if f.Selector != nil {
|
|
baseAns = f.Selector.SelectUDP(baseAns, payload)
|
|
}
|
|
ans = analyzersToUDPAnalyzers(baseAns)
|
|
}
|
|
entries := make([]*udpStreamEntry, 0, len(ans))
|
|
for _, a := range ans {
|
|
entries = append(entries, &udpStreamEntry{
|
|
Name: a.Name(),
|
|
Stream: a.NewUDP(analyzer.UDPInfo{
|
|
SrcIP: ipSrc,
|
|
DstIP: ipDst,
|
|
SrcPort: k.APort,
|
|
DstPort: k.BPort,
|
|
}, &analyzerLogger{
|
|
StreamID: id.Int64(),
|
|
Name: a.Name(),
|
|
Logger: f.Logger,
|
|
}),
|
|
HasLimit: a.Limit() > 0,
|
|
Quota: a.Limit(),
|
|
})
|
|
}
|
|
return &udpStream{
|
|
info: info,
|
|
virgin: true,
|
|
logger: f.Logger,
|
|
rulesetVersion: version,
|
|
rulesetSource: f.currentRuleset,
|
|
activeEntries: entries,
|
|
}
|
|
}
|
|
|
|
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
|
|
f.RulesetMutex.Lock()
|
|
defer f.RulesetMutex.Unlock()
|
|
f.Ruleset = r
|
|
f.RulesetVersion++
|
|
return nil
|
|
}
|
|
|
|
func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
|
f.RulesetMutex.RLock()
|
|
defer f.RulesetMutex.RUnlock()
|
|
return f.Ruleset, f.RulesetVersion
|
|
}
|
|
|
|
type udpStreamManager struct {
|
|
factory *udpStreamFactory
|
|
streams *lru.Cache[uint32, *udpStreamValue]
|
|
tupleIndex map[udpTupleKey]uint32
|
|
streamTuples map[uint32]udpTupleKey
|
|
stats *statsCounters
|
|
}
|
|
|
|
type udpStreamValue struct {
|
|
Stream *udpStream
|
|
Tuple udpTupleKey
|
|
}
|
|
|
|
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
|
|
fwd := v.Tuple == k
|
|
rev = v.Tuple == reverseTuple(k)
|
|
return fwd || rev, rev
|
|
}
|
|
|
|
// udpTupleKey identifies a UDP flow by its two endpoints, A and B. It is a
// comparable value type and is used directly as a map key.
type udpTupleKey struct {
	// AIP and BIP hold the raw address bytes, left-aligned; only the first
	// ALen / BLen bytes are significant.
	AIP, BIP [16]byte
	// ALen and BLen are the number of address bytes stored in AIP and BIP.
	ALen, BLen uint8
	// APort and BPort are the UDP ports of the two endpoints.
	APort, BPort uint16
}
|
|
|
|
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
|
m := &udpStreamManager{
|
|
factory: factory,
|
|
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
|
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
|
stats: stats,
|
|
}
|
|
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
|
if v != nil && v.Stream != nil {
|
|
v.Stream.Close()
|
|
}
|
|
m.removeTupleMappingLocked(k)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
m.streams = ss
|
|
return m, nil
|
|
}
|
|
|
|
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
|
value, ok := m.streams.Get(streamID)
|
|
if !ok {
|
|
if m.stats != nil {
|
|
m.stats.UDPTupleLookups.Add(1)
|
|
}
|
|
matchedKey, found := m.tupleIndex[tuple]
|
|
var matchedValue *udpStreamValue
|
|
var matchedRev bool
|
|
if found {
|
|
if m.stats != nil {
|
|
m.stats.UDPTupleHits.Add(1)
|
|
}
|
|
var hasValue bool
|
|
matchedValue, hasValue = m.streams.Get(matchedKey)
|
|
if !hasValue || matchedValue == nil {
|
|
delete(m.tupleIndex, tuple)
|
|
delete(m.streamTuples, matchedKey)
|
|
found = false
|
|
}
|
|
}
|
|
if found {
|
|
_, matchedRev = matchedValue.Match(tuple)
|
|
value = matchedValue
|
|
rev = matchedRev
|
|
if matchedKey != streamID {
|
|
m.streams.Remove(matchedKey)
|
|
m.streams.Add(streamID, matchedValue)
|
|
m.bindTupleLocked(streamID, tuple)
|
|
}
|
|
} else {
|
|
value = &udpStreamValue{
|
|
Stream: m.factory.New(tuple, payload, uc),
|
|
Tuple: tuple,
|
|
}
|
|
m.streams.Add(streamID, value)
|
|
m.bindTupleLocked(streamID, tuple)
|
|
}
|
|
} else {
|
|
ok, rev = value.Match(tuple)
|
|
if !ok {
|
|
value.Stream.Close()
|
|
value = &udpStreamValue{
|
|
Stream: m.factory.New(tuple, payload, uc),
|
|
Tuple: tuple,
|
|
}
|
|
m.streams.Add(streamID, value)
|
|
m.bindTupleLocked(streamID, tuple)
|
|
}
|
|
}
|
|
if value.Stream.Accept(rev, uc) {
|
|
value.Stream.Feed(rev, payload, uc)
|
|
}
|
|
}
|
|
|
|
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
|
m.removeTupleMappingLocked(streamID)
|
|
m.tupleIndex[key] = streamID
|
|
m.streamTuples[streamID] = key
|
|
}
|
|
|
|
func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
|
|
if key, ok := m.streamTuples[streamID]; ok {
|
|
delete(m.streamTuples, streamID)
|
|
current, exists := m.tupleIndex[key]
|
|
if exists && current == streamID {
|
|
delete(m.tupleIndex, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
|
|
srcRaw := []byte(srcIP)
|
|
dstRaw := []byte(dstIP)
|
|
|
|
if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
|
|
srcRaw, dstRaw = dstRaw, srcRaw
|
|
srcPort, dstPort = dstPort, srcPort
|
|
}
|
|
|
|
var key udpTupleKey
|
|
key.ALen = uint8(copy(key.AIP[:], srcRaw))
|
|
key.BLen = uint8(copy(key.BIP[:], dstRaw))
|
|
key.APort = srcPort
|
|
key.BPort = dstPort
|
|
return key
|
|
}
|
|
|
|
func reverseTuple(k udpTupleKey) udpTupleKey {
|
|
var r udpTupleKey
|
|
r.ALen = k.BLen
|
|
r.BLen = k.ALen
|
|
r.AIP = k.BIP
|
|
r.BIP = k.AIP
|
|
r.APort = k.BPort
|
|
r.BPort = k.APort
|
|
return r
|
|
}
|
|
|
|
// compareIPEndpoint orders two (IP, port) endpoints and returns -1, 0 or 1.
// Endpoints are compared by address length first (shorter sorts first), then
// by address bytes, and finally by port.
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
	switch {
	case len(aIP) < len(bIP):
		return -1
	case len(aIP) > len(bIP):
		return 1
	}

	if byAddr := bytes.Compare(aIP, bIP); byAddr != 0 {
		return byAddr
	}

	switch {
	case aPort < bPort:
		return -1
	case aPort > bPort:
		return 1
	default:
		return 0
	}
}
|
|
|
|
type udpStream struct {
|
|
info ruleset.StreamInfo
|
|
virgin bool // true if no packets have been processed
|
|
logger Logger
|
|
rulesetVersion uint64
|
|
rulesetSource func() (ruleset.Ruleset, uint64)
|
|
activeEntries []*udpStreamEntry
|
|
doneEntries []*udpStreamEntry
|
|
lastVerdict udpVerdict
|
|
}
|
|
|
|
type udpStreamEntry struct {
|
|
Name string
|
|
Stream analyzer.UDPStream
|
|
HasLimit bool
|
|
Quota int
|
|
}
|
|
|
|
func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
|
|
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
|
|
return true
|
|
} else {
|
|
uc.Verdict = s.lastVerdict
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
|
|
updated := false
|
|
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
|
entry := s.activeEntries[i]
|
|
update, closeUpdate, done := s.feedEntry(entry, rev, payload)
|
|
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
|
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
|
updated = updated || up1 || up2
|
|
if done {
|
|
s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
|
|
s.doneEntries = append(s.doneEntries, entry)
|
|
}
|
|
}
|
|
rs, version := s.currentRuleset()
|
|
rulesetChanged := version != s.rulesetVersion
|
|
s.rulesetVersion = version
|
|
if updated || s.virgin || rulesetChanged {
|
|
s.virgin = false
|
|
s.logger.UDPStreamPropUpdate(s.info, false)
|
|
// Match properties against ruleset
|
|
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
|
if rs != nil {
|
|
result = rs.Match(s.info)
|
|
}
|
|
action := result.Action
|
|
if action == ruleset.ActionModify {
|
|
// Call the modifier instance
|
|
udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance)
|
|
if !ok {
|
|
// Not for UDP, fallback to maybe
|
|
s.logger.ModifyError(s.info, errInvalidModifier)
|
|
action = ruleset.ActionMaybe
|
|
} else {
|
|
var err error
|
|
uc.Packet, err = udpMI.Process(payload)
|
|
if err != nil {
|
|
// Modifier error, fallback to maybe
|
|
s.logger.ModifyError(s.info, err)
|
|
action = ruleset.ActionMaybe
|
|
}
|
|
}
|
|
}
|
|
if action != ruleset.ActionMaybe {
|
|
verdict, final := actionToUDPVerdict(action)
|
|
s.lastVerdict = verdict
|
|
uc.Verdict = verdict
|
|
s.logger.UDPStreamAction(s.info, action, false)
|
|
if final {
|
|
s.closeActiveEntries()
|
|
}
|
|
}
|
|
}
|
|
if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept {
|
|
// All entries are done but no verdict issued, accept stream
|
|
s.lastVerdict = udpVerdictAcceptStream
|
|
uc.Verdict = udpVerdictAcceptStream
|
|
s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
|
|
}
|
|
}
|
|
|
|
func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) {
|
|
if s.rulesetSource == nil {
|
|
return nil, s.rulesetVersion
|
|
}
|
|
return s.rulesetSource()
|
|
}
|
|
|
|
func (s *udpStream) rulesetChanged() bool {
|
|
_, version := s.currentRuleset()
|
|
return version != s.rulesetVersion
|
|
}
|
|
|
|
func (s *udpStream) Close() {
|
|
s.closeActiveEntries()
|
|
}
|
|
|
|
func (s *udpStream) closeActiveEntries() {
|
|
// Signal close to all active entries & move them to doneEntries
|
|
updated := false
|
|
for _, entry := range s.activeEntries {
|
|
update := entry.Stream.Close(false)
|
|
up := processPropUpdate(s.info.Props, entry.Name, update)
|
|
updated = updated || up
|
|
}
|
|
if updated {
|
|
s.logger.UDPStreamPropUpdate(s.info, true)
|
|
}
|
|
s.doneEntries = append(s.doneEntries, s.activeEntries...)
|
|
s.activeEntries = nil
|
|
}
|
|
|
|
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
|
update, done = entry.Stream.Feed(rev, data)
|
|
if entry.HasLimit {
|
|
entry.Quota -= len(data)
|
|
if entry.Quota <= 0 {
|
|
// Quota exhausted, signal close & move to doneEntries
|
|
closeUpdate = entry.Stream.Close(true)
|
|
done = true
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
|
|
udpAns := make([]analyzer.UDPAnalyzer, 0, len(ans))
|
|
for _, a := range ans {
|
|
if udpM, ok := a.(analyzer.UDPAnalyzer); ok {
|
|
udpAns = append(udpAns, udpM)
|
|
}
|
|
}
|
|
return udpAns
|
|
}
|
|
|
|
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
|
|
switch a {
|
|
case ruleset.ActionMaybe:
|
|
return udpVerdictAccept, false
|
|
case ruleset.ActionAllow:
|
|
return udpVerdictAcceptStream, true
|
|
case ruleset.ActionBlock:
|
|
return udpVerdictDropStream, true
|
|
case ruleset.ActionDrop:
|
|
return udpVerdictDrop, false
|
|
case ruleset.ActionModify:
|
|
return udpVerdictAcceptModify, false
|
|
default:
|
|
// Should never happen
|
|
return udpVerdictAccept, false
|
|
}
|
|
}
|