package engine import ( "bytes" "errors" "net" "sync" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/modifier" "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" lru "github.com/hashicorp/golang-lru/v2" ) // udpVerdict is a subset of io.Verdict for UDP streams. // For UDP, we support all verdicts. type udpVerdict io.Verdict const ( udpVerdictAccept = udpVerdict(io.VerdictAccept) udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify) udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream) udpVerdictDrop = udpVerdict(io.VerdictDrop) udpVerdictDropStream = udpVerdict(io.VerdictDropStream) ) var errInvalidModifier = errors.New("invalid modifier") type udpContext struct { Verdict udpVerdict Packet []byte SrcMAC, DstMAC net.HardwareAddr } type udpStreamFactory struct { WorkerID int Logger Logger Node *snowflake.Node Selector *analyzerSelector Stats *statsCounters RulesetMutex sync.RWMutex Ruleset ruleset.Ruleset RulesetVersion uint64 } func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream { id := f.Node.Generate() ipSrc := net.IP(k.AIP[:k.ALen]) ipDst := net.IP(k.BIP[:k.BLen]) info := ruleset.StreamInfo{ ID: id.Int64(), Protocol: ruleset.ProtocolUDP, SrcMAC: append(net.HardwareAddr(nil), uc.SrcMAC...), DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...), SrcIP: ipSrc, DstIP: ipDst, SrcPort: k.APort, DstPort: k.BPort, Props: make(analyzer.CombinedPropMap), } f.Logger.UDPStreamNew(f.WorkerID, info) rs, version := f.currentRuleset() var ans []analyzer.UDPAnalyzer if rs != nil { baseAns := rs.Analyzers(info) if f.Selector != nil { baseAns = f.Selector.SelectUDP(baseAns, payload) } ans = analyzersToUDPAnalyzers(baseAns) } entries := make([]*udpStreamEntry, 0, len(ans)) for _, a := range ans { entries = append(entries, &udpStreamEntry{ Name: a.Name(), Stream: a.NewUDP(analyzer.UDPInfo{ SrcIP: ipSrc, DstIP: ipDst, SrcPort: k.APort, DstPort: k.BPort, }, &analyzerLogger{ StreamID: id.Int64(), Name: a.Name(), Logger: f.Logger, }), HasLimit: a.Limit() > 0, Quota: a.Limit(), }) } return &udpStream{ info: info, virgin: true, logger: f.Logger, rulesetVersion: version, rulesetSource: f.currentRuleset, activeEntries: entries, } } func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error { f.RulesetMutex.Lock() defer f.RulesetMutex.Unlock() f.Ruleset = r f.RulesetVersion++ return nil } func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) { f.RulesetMutex.RLock() defer f.RulesetMutex.RUnlock() return f.Ruleset, f.RulesetVersion } type udpStreamManager struct { factory *udpStreamFactory streams *lru.Cache[uint32, *udpStreamValue] tupleIndex map[udpTupleKey]uint32 streamTuples map[uint32]udpTupleKey stats *statsCounters } type udpStreamValue struct { Stream *udpStream Tuple udpTupleKey } func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) { fwd := v.Tuple == k rev = v.Tuple == reverseTuple(k) return fwd || rev, rev } type udpTupleKey struct { AIP [16]byte BIP [16]byte ALen uint8 BLen uint8 APort uint16 BPort uint16 } func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { m := &udpStreamManager{ factory: factory, tupleIndex: make(map[udpTupleKey]uint32, maxStreams), streamTuples: make(map[uint32]udpTupleKey, maxStreams), stats: stats, } ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) { if v != nil && v.Stream != nil { v.Stream.Close() } m.removeTupleMappingLocked(k) }) if err != nil { return nil, err } m.streams = ss return m, nil } func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) { value, ok := m.streams.Get(streamID) if !ok { if m.stats != nil { m.stats.UDPTupleLookups.Add(1) } matchedKey, found := m.tupleIndex[tuple] var matchedValue *udpStreamValue var matchedRev bool if found { if m.stats != nil { m.stats.UDPTupleHits.Add(1) } var hasValue bool matchedValue, hasValue = m.streams.Get(matchedKey) if !hasValue || matchedValue == nil { delete(m.tupleIndex, tuple) delete(m.streamTuples, matchedKey) found = false } } if found { _, matchedRev = matchedValue.Match(tuple) value = matchedValue rev = matchedRev if matchedKey != streamID { m.streams.Remove(matchedKey) m.streams.Add(streamID, matchedValue) m.bindTupleLocked(streamID, tuple) } } else { value = &udpStreamValue{ Stream: m.factory.New(tuple, payload, uc), Tuple: tuple, } m.streams.Add(streamID, value) m.bindTupleLocked(streamID, tuple) } } else { ok, rev = value.Match(tuple) if !ok { value.Stream.Close() value = &udpStreamValue{ Stream: m.factory.New(tuple, payload, uc), Tuple: tuple, } m.streams.Add(streamID, value) m.bindTupleLocked(streamID, tuple) } } if value.Stream.Accept(rev, uc) { value.Stream.Feed(rev, payload, uc) } } func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) { m.removeTupleMappingLocked(streamID) m.tupleIndex[key] = streamID m.streamTuples[streamID] = key } func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) { if key, ok := m.streamTuples[streamID]; ok { delete(m.streamTuples, streamID) current, exists := m.tupleIndex[key] if exists && current == streamID { delete(m.tupleIndex, key) } } } func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey { srcRaw := []byte(srcIP) dstRaw := []byte(dstIP) if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 { srcRaw, dstRaw = dstRaw, srcRaw srcPort, dstPort = dstPort, srcPort } var key udpTupleKey key.ALen = uint8(copy(key.AIP[:], srcRaw)) key.BLen = uint8(copy(key.BIP[:], dstRaw)) key.APort = srcPort key.BPort = dstPort return key } func reverseTuple(k udpTupleKey) udpTupleKey { var r udpTupleKey r.ALen = k.BLen r.BLen = k.ALen r.AIP = k.BIP r.BIP = k.AIP r.APort = k.BPort r.BPort = k.APort return r } func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int { if len(aIP) != len(bIP) { if len(aIP) < len(bIP) { return -1 } return 1 } if c := bytes.Compare(aIP, bIP); c != 0 { return c } if aPort < bPort { return -1 } if aPort > bPort { return 1 } return 0 } type udpStream struct { info ruleset.StreamInfo virgin bool // true if no packets have been processed logger Logger rulesetVersion uint64 rulesetSource func() (ruleset.Ruleset, uint64) activeEntries []*udpStreamEntry doneEntries []*udpStreamEntry lastVerdict udpVerdict } type udpStreamEntry struct { Name string Stream analyzer.UDPStream HasLimit bool Quota int } func (s *udpStream) Accept(rev bool, uc *udpContext) bool { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { return true } else { uc.Verdict = s.lastVerdict return false } } func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) { updated := false for i := len(s.activeEntries) - 1; i >= 0; i-- { entry := s.activeEntries[i] update, closeUpdate, done := s.feedEntry(entry, rev, payload) up1 := processPropUpdate(s.info.Props, entry.Name, update) up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate) updated = updated || up1 || up2 if done { s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...) s.doneEntries = append(s.doneEntries, entry) } } rs, version := s.currentRuleset() rulesetChanged := version != s.rulesetVersion s.rulesetVersion = version if updated || s.virgin || rulesetChanged { s.virgin = false s.logger.UDPStreamPropUpdate(s.info, false) // Match properties against ruleset result := ruleset.MatchResult{Action: ruleset.ActionMaybe} if rs != nil { result = rs.Match(s.info) } action := result.Action if action == ruleset.ActionModify { // Call the modifier instance udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance) if !ok { // Not for UDP, fallback to maybe s.logger.ModifyError(s.info, errInvalidModifier) action = ruleset.ActionMaybe } else { var err error uc.Packet, err = udpMI.Process(payload) if err != nil { // Modifier error, fallback to maybe s.logger.ModifyError(s.info, err) action = ruleset.ActionMaybe } } } if action != ruleset.ActionMaybe { verdict, final := actionToUDPVerdict(action) s.lastVerdict = verdict uc.Verdict = verdict s.logger.UDPStreamAction(s.info, action, false) if final { s.closeActiveEntries() } } } if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept { // All entries are done but no verdict issued, accept stream s.lastVerdict = udpVerdictAcceptStream uc.Verdict = udpVerdictAcceptStream s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true) } } func (s *udpStream) currentRuleset() (ruleset.Ruleset, uint64) { if s.rulesetSource == nil { return nil, s.rulesetVersion } return s.rulesetSource() } func (s *udpStream) rulesetChanged() bool { _, version := s.currentRuleset() return version != s.rulesetVersion } func (s *udpStream) Close() { s.closeActiveEntries() } func (s *udpStream) closeActiveEntries() { // Signal close to all active entries & move them to doneEntries updated := false for _, entry := range s.activeEntries { update := entry.Stream.Close(false) up := processPropUpdate(s.info.Props, entry.Name, update) updated = updated || up } if updated { s.logger.UDPStreamPropUpdate(s.info, true) } s.doneEntries = append(s.doneEntries, s.activeEntries...) s.activeEntries = nil } func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { update, done = entry.Stream.Feed(rev, data) if entry.HasLimit { entry.Quota -= len(data) if entry.Quota <= 0 { // Quota exhausted, signal close & move to doneEntries closeUpdate = entry.Stream.Close(true) done = true } } return } func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer { udpAns := make([]analyzer.UDPAnalyzer, 0, len(ans)) for _, a := range ans { if udpM, ok := a.(analyzer.UDPAnalyzer); ok { udpAns = append(udpAns, udpM) } } return udpAns } func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) { switch a { case ruleset.ActionMaybe: return udpVerdictAccept, false case ruleset.ActionAllow: return udpVerdictAcceptStream, true case ruleset.ActionBlock: return udpVerdictDropStream, true case ruleset.ActionDrop: return udpVerdictDrop, false case ruleset.ActionModify: return udpVerdictAcceptModify, false default: // Should never happen return udpVerdictAccept, false } }