fix: eliminate stale verdict poisoning, memory leaks, data races, and per-packet allocations in engine

This commit is contained in:
2026-05-15 02:08:22 +00:00
parent bc25169f41
commit 301c252c43
15 changed files with 222 additions and 163 deletions
+2 -2
View File
@@ -130,8 +130,8 @@ func (s *httpStream) parseResponseLine() utils.LSMAction {
return utils.LSMActionCancel
}
version := fields[0]
status, _ := strconv.Atoi(fields[1])
if !strings.HasPrefix(version, "HTTP/") || status == 0 {
status, err := strconv.Atoi(fields[1])
if err != nil || !strings.HasPrefix(version, "HTTP/") || status == 0 {
// Invalid version
return utils.LSMActionCancel
}
+4 -2
View File
@@ -6,6 +6,8 @@ import (
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
const maxHandshakeLen = 65536
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)
type TLSAnalyzer struct{}
@@ -123,7 +125,7 @@ func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
}
s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.clientHelloLen < minDataSize {
if s.clientHelloLen < minDataSize || s.clientHelloLen > maxHandshakeLen {
return utils.LSMActionCancel
}
@@ -167,7 +169,7 @@ func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
}
s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.serverHelloLen < minDataSize {
if s.serverHelloLen < minDataSize || s.serverHelloLen > maxHandshakeLen {
return utils.LSMActionCancel
}
+3 -2
View File
@@ -38,6 +38,7 @@ const (
OpenVPNMinPktLen = 6
OpenVPNTCPPktDefaultLimit = 256
OpenVPNUDPPktDefaultLimit = 256
OpenVPNTCPMaxPktLen = 4096
)
type OpenVPNAnalyzer struct{}
@@ -195,7 +196,7 @@ func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream {
}
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
if len(data) == 0 {
if len(data) < OpenVPNMinPktLen {
return nil, false
}
var update *analyzer.PropUpdate
@@ -338,7 +339,7 @@ func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAc
return nil, utils.LSMActionPause
}
if pktLen < OpenVPNMinPktLen {
if pktLen < OpenVPNMinPktLen || pktLen > OpenVPNTCPMaxPktLen {
return nil, utils.LSMActionCancel
}
+4
View File
@@ -14,6 +14,7 @@ import (
const (
quicInvalidCountThreshold = 16
quicMaxCryptoDataLen = 256 * 1024
quicMaxFrameEntries = 100
)
var (
@@ -158,6 +159,9 @@ func (s *quicStream) mergeFrame(offset int64, data []byte) {
if len(data) == 0 || offset < 0 {
return
}
if len(s.frames) >= quicMaxFrameEntries {
return
}
if s.frames == nil {
s.frames = make(map[int64][]byte)
}
+33 -4
View File
@@ -5,6 +5,7 @@ import (
"runtime"
"sync"
"sync/atomic"
"time"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
@@ -15,8 +16,14 @@ var _ Engine = (*engine)(nil)
type verdictEntry struct {
Verdict io.Verdict
Gen int64
CreatedAt time.Time
}
const (
verdictTTL = 15 * time.Second
verdictSweepInterval = 15 * time.Second
)
type engine struct {
logger Logger
io io.PacketIO
@@ -39,7 +46,7 @@ func NewEngine(config Config) (Engine, error) {
}
overflowPolicy := config.OverflowPolicy
if overflowPolicy == "" {
overflowPolicy = OverflowPolicyAccept
overflowPolicy = OverflowPolicyDrop
}
selectionMode := config.AnalyzerSelectionMode
if selectionMode == "" {
@@ -83,7 +90,6 @@ func NewEngine(config Config) (Engine, error) {
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
e.verdictsGen.Add(1)
e.verdicts = sync.Map{}
for _, w := range e.workers {
if err := w.UpdateRuleset(r); err != nil {
return err
@@ -100,6 +106,7 @@ func (e *engine) Run(ctx context.Context) error {
go w.Run(ioCtx)
}
go e.drainResults(ioCtx)
go e.sweepVerdicts(ioCtx)
errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
@@ -124,6 +131,7 @@ func (e *engine) Run(ctx context.Context) error {
func (e *engine) dispatch(p io.Packet) bool {
streamID := p.StreamID()
if streamID != 0 {
if v, ok := e.verdicts.Load(streamID); ok {
entry := v.(verdictEntry)
if entry.Gen == e.verdictsGen.Load() {
@@ -131,6 +139,7 @@ func (e *engine) dispatch(p io.Packet) bool {
return true
}
}
}
data := p.Data()
if !validPacket(data) {
@@ -163,12 +172,32 @@ func (e *engine) dispatch(p io.Packet) bool {
}
func (e *engine) applyWorkerResult(r workerResult) {
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
}
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
}
// sweepVerdicts periodically evicts cached stream verdicts older than
// verdictTTL. It runs until ctx is cancelled; the ticker is stopped on exit.
func (e *engine) sweepVerdicts(ctx context.Context) {
	ticker := time.NewTicker(verdictSweepInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case tick := <-ticker.C:
			// Delete-during-Range is safe for sync.Map.
			e.verdicts.Range(func(id, raw any) bool {
				if ve := raw.(verdictEntry); tick.Sub(ve.CreatedAt) > verdictTTL {
					e.verdicts.Delete(id)
				}
				return true
			})
		}
	}
}
func validPacket(data []byte) bool {
if len(data) == 0 {
return false
+4 -1
View File
@@ -59,7 +59,10 @@ func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
return
}
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 || totalLen > len(data) {
if totalLen < int(ihl)*4 {
return
}
if totalLen > len(data) {
totalLen = len(data)
}
return L3Info{
+13 -28
View File
@@ -1,7 +1,6 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
type fixedRuleset struct {
@@ -60,25 +57,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
udp := &layers.UDP{
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
payload := []byte("query")
ctx := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx)
s := f.New(tuple, payload, ctx)
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
}
if !s.Accept(udp, false, ctx) {
if !s.Accept(false, ctx) {
t.Fatalf("unexpected Accept=false for virgin stream")
}
s.Feed(udp, false, ctx)
s.Feed(false, payload, ctx)
if ctx.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
}
@@ -96,21 +87,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
udp := &layers.UDP{
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
payload := []byte("query")
ctx1 := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1)
if !s.Accept(udp, false, ctx1) {
s := f.New(tuple, payload, ctx1)
if !s.Accept(false, ctx1) {
t.Fatalf("unexpected Accept=false before first feed")
}
s.Feed(udp, false, ctx1)
s.Feed(false, payload, ctx1)
if ctx1.Verdict != udpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
}
@@ -120,16 +105,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
}
ctx2 := &udpContext{Verdict: udpVerdictAccept}
if !s.Accept(udp, false, ctx2) {
if !s.Accept(false, ctx2) {
t.Fatalf("expected Accept=true after ruleset update")
}
s.Feed(udp, false, ctx2)
s.Feed(false, payload, ctx2)
if ctx2.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
}
ctx3 := &udpContext{Verdict: udpVerdictAccept}
if s.Accept(udp, false, ctx3) {
if s.Accept(false, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
}
if ctx3.Verdict != udpVerdictDropStream {
+24 -3
View File
@@ -3,6 +3,7 @@ package engine
import (
"net"
"sync"
"time"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
@@ -13,6 +14,8 @@ import (
const tcpFlowMaxBuffer = 16384
const tcpFlowIdleTimeout = 10 * time.Minute
type tcpFlowDirection uint8
const (
@@ -37,6 +40,7 @@ type tcpFlow struct {
doneEntries []*tcpFlowEntry
lastVerdict io.Verdict
feedCalled [2]bool
lastSeen time.Time
}
type tcpFlowEntry struct {
@@ -67,16 +71,17 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
expected := f.dirSeq[dir]
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
f.feedCalled[dir] = true
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
propUpdated = f.feedAnalyzers(rev)
}
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
}
}
f.runMatch(rs, version, rulesetChanged, propUpdated)
f.maybeFinalizeVerdict()
f.lastSeen = time.Now()
return f.lastVerdict
}
@@ -218,7 +223,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
Props: make(analyzer.CombinedPropMap),
}
m.logger.TCPStreamNew(m.workerID, info)
rs, version := m.rulesetSource()
var rs ruleset.Ruleset
var version uint64
if m.rulesetSource != nil {
rs, version = m.rulesetSource()
}
var ans []analyzer.TCPAnalyzer
if rs != nil {
baseAns := rs.Analyzers(info)
@@ -255,6 +264,7 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
rulesetVersion: version,
activeEntries: entries,
lastVerdict: io.VerdictAccept,
lastSeen: time.Now(),
}
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
return flow
@@ -266,6 +276,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
}
}
// cleanupIdle removes TCP flows that have seen no traffic for longer than
// tcpFlowIdleTimeout, closing each flow's active analyzer entries first.
// Called periodically by the worker loop; takes m.mu for the full sweep.
func (m *tcpFlowManager) cleanupIdle(now time.Time) {
	m.mu.Lock()
	defer m.mu.Unlock()
	for streamID, fl := range m.flows {
		if now.Sub(fl.lastSeen) <= tcpFlowIdleTimeout {
			continue
		}
		fl.closeActiveEntries()
		delete(m.flows, streamID)
	}
}
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
if !entry.HasLimit {
update, done = entry.Stream.Feed(rev, true, false, 0, data)
+49 -53
View File
@@ -12,8 +12,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2"
)
@@ -49,9 +47,10 @@ type udpStreamFactory struct {
RulesetVersion uint64
}
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
id := f.Node.Generate()
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
ipSrc := net.IP(k.AIP[:k.ALen])
ipDst := net.IP(k.BIP[:k.BLen])
info := ruleset.StreamInfo{
ID: id.Int64(),
Protocol: ruleset.ProtocolUDP,
@@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
SrcIP: ipSrc,
DstIP: ipDst,
SrcPort: uint16(udp.SrcPort),
DstPort: uint16(udp.DstPort),
SrcPort: k.APort,
DstPort: k.BPort,
Props: make(analyzer.CombinedPropMap),
}
f.Logger.UDPStreamNew(f.WorkerID, info)
@@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
if rs != nil {
baseAns := rs.Analyzers(info)
if f.Selector != nil {
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
baseAns = f.Selector.SelectUDP(baseAns, payload)
}
ans = analyzersToUDPAnalyzers(baseAns)
}
// Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans))
for _, a := range ans {
entries = append(entries, &udpStreamEntry{
@@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
Stream: a.NewUDP(analyzer.UDPInfo{
SrcIP: ipSrc,
DstIP: ipDst,
SrcPort: uint16(udp.SrcPort),
DstPort: uint16(udp.DstPort),
SrcPort: k.APort,
DstPort: k.BPort,
}, &analyzerLogger{
StreamID: id.Int64(),
Name: a.Name(),
@@ -126,8 +124,13 @@ type udpStreamManager struct {
type udpStreamValue struct {
Stream *udpStream
IPFlow gopacket.Flow
UDPFlow gopacket.Flow
Tuple udpTupleKey
}
// Match reports whether tuple k identifies this stream, either in the same
// direction as the stored tuple or in the reverse direction. rev is true
// when k matches the reversed tuple (note: for a symmetric tuple both the
// forward and reverse comparisons hold, and rev is reported as true).
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
	forward := v.Tuple == k
	backward := v.Tuple == reverseTuple(k)
	return forward || backward, backward
}
type udpTupleKey struct {
@@ -139,12 +142,6 @@ type udpTupleKey struct {
BPort uint16
}
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
return fwd || rev, rev
}
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
m := &udpStreamManager{
factory: factory,
@@ -153,6 +150,9 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
stats: stats,
}
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
if v != nil && v.Stream != nil {
v.Stream.Close()
}
m.removeTupleMappingLocked(k)
})
if err != nil {
@@ -162,16 +162,12 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
return m, nil
}
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
rev := false
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
value, ok := m.streams.Get(streamID)
tuple := canonicalUDPTupleKey(ipFlow, udp)
if !ok {
if m.stats != nil {
m.stats.UDPTupleLookups.Add(1)
}
// Conntrack IDs can change during early flow lifetime on some systems.
// Rebind by canonical 5-tuple in O(1).
matchedKey, found := m.tupleIndex[tuple]
var matchedValue *udpStreamValue
var matchedRev bool
@@ -188,7 +184,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
}
}
if found {
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
_, matchedRev = matchedValue.Match(tuple)
value = matchedValue
rev = matchedRev
if matchedKey != streamID {
@@ -197,32 +193,27 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
m.bindTupleLocked(streamID, tuple)
}
} else {
// New stream
value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
IPFlow: ipFlow,
UDPFlow: udp.TransportFlow(),
Stream: m.factory.New(tuple, payload, uc),
Tuple: tuple,
}
m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
}
} else {
// Stream ID exists, but is it really the same stream?
ok, rev = value.Match(ipFlow, udp.TransportFlow())
ok, rev = value.Match(tuple)
if !ok {
// It's not - close the old stream & replace it with a new one
value.Stream.Close()
value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
IPFlow: ipFlow,
UDPFlow: udp.TransportFlow(),
Stream: m.factory.New(tuple, payload, uc),
Tuple: tuple,
}
m.streams.Add(streamID, value)
m.bindTupleLocked(streamID, tuple)
}
}
if value.Stream.Accept(udp, rev, uc) {
value.Stream.Feed(udp, rev, uc)
if value.Stream.Accept(rev, uc) {
value.Stream.Feed(rev, payload, uc)
}
}
@@ -242,25 +233,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
}
}
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
srcIP := ipFlow.Src().Raw()
dstIP := ipFlow.Dst().Raw()
srcPort := uint16(udp.SrcPort)
dstPort := uint16(udp.DstPort)
func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
srcRaw := []byte(srcIP)
dstRaw := []byte(dstIP)
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
srcIP, dstIP = dstIP, srcIP
if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
srcRaw, dstRaw = dstRaw, srcRaw
srcPort, dstPort = dstPort, srcPort
}
var key udpTupleKey
key.ALen = uint8(copy(key.AIP[:], srcIP))
key.BLen = uint8(copy(key.BIP[:], dstIP))
key.ALen = uint8(copy(key.AIP[:], srcRaw))
key.BLen = uint8(copy(key.BIP[:], dstRaw))
key.APort = srcPort
key.BPort = dstPort
return key
}
// reverseTuple returns k with its A and B endpoints swapped, i.e. the key
// describing the same 5-tuple seen from the opposite direction.
func reverseTuple(k udpTupleKey) udpTupleKey {
	return udpTupleKey{
		AIP:   k.BIP,
		BIP:   k.AIP,
		ALen:  k.BLen,
		BLen:  k.ALen,
		APort: k.BPort,
		BPort: k.APort,
	}
}
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
if len(aIP) != len(bIP) {
if len(aIP) < len(bIP) {
@@ -298,11 +298,8 @@ type udpStreamEntry struct {
Quota int
}
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched.
return true
} else {
uc.Verdict = s.lastVerdict
@@ -310,12 +307,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
}
}
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
updated := false
for i := len(s.activeEntries) - 1; i >= 0; i-- {
// Important: reverse order so we can remove entries
entry := s.activeEntries[i]
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
update, closeUpdate, done := s.feedEntry(entry, rev, payload)
up1 := processPropUpdate(s.info.Props, entry.Name, update)
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
updated = updated || up1 || up2
@@ -345,7 +341,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
action = ruleset.ActionMaybe
} else {
var err error
uc.Packet, err = udpMI.Process(udp.Payload)
uc.Packet, err = udpMI.Process(payload)
if err != nil {
// Modifier error, fallback to maybe
s.logger.ModifyError(s.info, err)
+26 -30
View File
@@ -1,20 +1,16 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
type legacyUDPStreamValue struct {
IPFlow gopacket.Flow
UDPFlow gopacket.Flow
Tuple udpTupleKey
}
type emptyRuleset struct{}
@@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
}
const flowCount = 20000
flows := make([]gopacket.Flow, flowCount)
udps := make([]*layers.UDP, flowCount)
tuples := make([]udpTupleKey, flowCount)
payloads := make([][]byte, flowCount)
for i := 0; i < flowCount; i++ {
a := byte(i >> 8)
c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
udps[i] = &layers.UDP{
SrcPort: layers.UDPPort(1024 + i%20000),
DstPort: layers.UDPPort(20000 + (i*7)%20000),
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
}
var t udpTupleKey
t.AIP = [16]byte{10, a, 0, c}
t.ALen = 4
t.BIP = [16]byte{172, 16, a, c}
t.BLen = 4
t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
payloads[i] = []byte{0x01, 0x00, 0x00, 0x00}
}
ctx := &udpContext{Verdict: udpVerdictAccept}
@@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
}
ctx.Verdict = udpVerdictAccept
ctx.Packet = nil
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx)
}
}
@@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
const flowCount = 5000
flows := make([]gopacket.Flow, flowCount)
udps := make([]*layers.UDP, flowCount)
tuples := make([]udpTupleKey, flowCount)
for i := 0; i < flowCount; i++ {
a := byte(i >> 8)
c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
udps[i] = &layers.UDP{
SrcPort: layers.UDPPort(1024 + i%20000),
DstPort: layers.UDPPort(20000 + (i*7)%20000),
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
}
var t udpTupleKey
t.AIP = [16]byte{10, a, 0, c}
t.ALen = 4
t.BIP = [16]byte{172, 16, a, c}
t.BLen = 4
t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
}
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
keys := make([]uint32, 0, flowCount)
for i := 0; i < flowCount; i++ {
streamID := uint32(i + 1)
streams[streamID] = &legacyUDPStreamValue{
IPFlow: flows[i],
UDPFlow: udps[i].TransportFlow(),
}
streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]}
keys = append(keys, streamID)
}
@@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
if _, ok := streams[streamID]; ok {
continue
}
ipFlow := flows[idx]
udpFlow := udps[idx].TransportFlow()
tuple := tuples[idx]
revTuple := reverseTuple(tuple)
for _, k := range keys {
v, ok := streams[k]
if !ok || v == nil {
continue
}
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
if v.Tuple == tuple || v.Tuple == revTuple {
delete(streams, k)
streams[streamID] = v
break
+4 -7
View File
@@ -1,7 +1,6 @@
package engine
import (
"net"
"sync/atomic"
"testing"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
type countingRuleset struct {
@@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
t.Fatalf("new manager: %v", err)
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443}
payload := []byte{0x01, 0x00, 0x00, 0x00}
ctx1 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(100, ipFlow, udp, ctx1)
mgr.MatchWithContext(100, tuple, false, payload, ctx1)
if got := newCalls.Load(); got != 1 {
t.Fatalf("new stream calls=%d want=1", got)
}
ctx2 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(200, ipFlow, udp, ctx2)
mgr.MatchWithContext(200, tuple, false, payload, ctx2)
if got := newCalls.Load(); got != 1 {
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
}
+11 -16
View File
@@ -3,6 +3,7 @@ package engine
import (
"context"
"net"
"time"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
@@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) {
func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id)
tcpSweepTicker := time.NewTicker(1 * time.Minute)
defer tcpSweepTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-tcpSweepTicker.C:
w.tcpFlowMgr.cleanupIdle(time.Now())
case wp := <-w.packetChan:
if wp == nil {
return
@@ -202,15 +209,6 @@ func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []by
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := l3.SrcIPAddr()
ipDst := l3.DstIPAddr()
endpointType := layers.EndpointIPv4
flowSrc := ipSrc.To4()
flowDst := ipDst.To4()
if l3.Version == 6 {
endpointType = layers.EndpointIPv6
flowSrc = ipSrc.To16()
flowDst = ipDst.To16()
}
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc)
@@ -221,12 +219,9 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
SrcMAC: srcMAC,
DstMAC: dstMAC,
}
// Temporarily set payload on a UDP layer so existing UDP handling works.
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
BaseLayer: layers.BaseLayer{Payload: payload},
SrcPort: layers.UDPPort(udp.SrcPort),
DstPort: layers.UDPPort(udp.DstPort),
}, uc)
tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort)
w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc)
return io.Verdict(uc.Verdict), uc.Packet
}
@@ -253,7 +248,7 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []b
if err != nil {
return io.VerdictAccept, nil
}
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...)
}
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
+1 -1
View File
@@ -456,7 +456,7 @@ func ctIDFromCtBytes(ct []byte) uint32 {
return 0
}
for _, attr := range ctAttrs {
if attr.Type == 12 { // CTA_ID
if attr.Type == 12 && len(attr.Data) >= 4 { // CTA_ID
return binary.BigEndian.Uint32(attr.Data)
}
}
+6
View File
@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"os"
"sync"
"time"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
@@ -31,6 +32,7 @@ type V2GeoLoader struct {
DownloadFunc func(filename, url string)
DownloadErrFunc func(err error)
mu sync.Mutex
geoipMap map[string]*v2geo.GeoIP
geositeMap map[string]*v2geo.GeoSite
}
@@ -80,6 +82,8 @@ func (l *V2GeoLoader) download(filename, url string) error {
}
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geoipMap != nil {
return l.geoipMap, nil
}
@@ -104,6 +108,8 @@ func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
}
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geositeMap != nil {
return l.geositeMap, nil
}
+30 -6
View File
@@ -519,7 +519,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoIP,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoIp(a, b), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
},
@@ -527,7 +532,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSite(a, b), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
},
@@ -535,7 +545,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil
a, ok1 := params[0].(string)
b, ok2 := params[1].(*geo.SiteConditionSet)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSiteSet(a, b), nil
},
Types: []reflect.Type{
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
@@ -556,7 +571,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
return nil
},
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.IPNet)
if !ok1 || !ok2 {
return false, nil
}
return builtins.MatchCIDR(a, b), nil
},
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
},
@@ -565,7 +585,6 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode
if len(*args) > 1 {
// Has the optional server argument
var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok {
@@ -595,9 +614,14 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
}()
}
a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.Resolver)
if !ok1 || !ok2 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
out, err := b.LookupHost(ctx, a)
if err != nil && stats != nil {
stats.LookupErrors.Add(1)
}