fix: eliminate stale verdict poisoning, memory leaks, data races, and per-packet allocations in engine
This commit is contained in:
+40
-11
@@ -5,6 +5,7 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
@@ -13,10 +14,16 @@ import (
|
||||
var _ Engine = (*engine)(nil)
|
||||
|
||||
type verdictEntry struct {
|
||||
Verdict io.Verdict
|
||||
Gen int64
|
||||
Verdict io.Verdict
|
||||
Gen int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verdictTTL = 15 * time.Second
|
||||
verdictSweepInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
type engine struct {
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
@@ -39,7 +46,7 @@ func NewEngine(config Config) (Engine, error) {
|
||||
}
|
||||
overflowPolicy := config.OverflowPolicy
|
||||
if overflowPolicy == "" {
|
||||
overflowPolicy = OverflowPolicyAccept
|
||||
overflowPolicy = OverflowPolicyDrop
|
||||
}
|
||||
selectionMode := config.AnalyzerSelectionMode
|
||||
if selectionMode == "" {
|
||||
@@ -83,7 +90,6 @@ func NewEngine(config Config) (Engine, error) {
|
||||
|
||||
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
e.verdictsGen.Add(1)
|
||||
e.verdicts = sync.Map{}
|
||||
for _, w := range e.workers {
|
||||
if err := w.UpdateRuleset(r); err != nil {
|
||||
return err
|
||||
@@ -100,6 +106,7 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
go e.drainResults(ioCtx)
|
||||
go e.sweepVerdicts(ioCtx)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
@@ -124,11 +131,13 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
func (e *engine) dispatch(p io.Packet) bool {
|
||||
streamID := p.StreamID()
|
||||
|
||||
if v, ok := e.verdicts.Load(streamID); ok {
|
||||
entry := v.(verdictEntry)
|
||||
if entry.Gen == e.verdictsGen.Load() {
|
||||
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
||||
return true
|
||||
if streamID != 0 {
|
||||
if v, ok := e.verdicts.Load(streamID); ok {
|
||||
entry := v.(verdictEntry)
|
||||
if entry.Gen == e.verdictsGen.Load() {
|
||||
_ = e.io.SetVerdict(p, entry.Verdict, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,12 +172,32 @@ func (e *engine) dispatch(p io.Packet) bool {
|
||||
}
|
||||
|
||||
func (e *engine) applyWorkerResult(r workerResult) {
|
||||
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
|
||||
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
|
||||
if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
|
||||
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
|
||||
}
|
||||
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
||||
}
|
||||
|
||||
func (e *engine) sweepVerdicts(ctx context.Context) {
|
||||
ticker := time.NewTicker(verdictSweepInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
e.verdicts.Range(func(key, value interface{}) bool {
|
||||
entry := value.(verdictEntry)
|
||||
if now.Sub(entry.CreatedAt) > verdictTTL {
|
||||
e.verdicts.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validPacket(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return false
|
||||
|
||||
+4
-1
@@ -59,7 +59,10 @@ func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
|
||||
return
|
||||
}
|
||||
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
|
||||
if totalLen < int(ihl)*4 || totalLen > len(data) {
|
||||
if totalLen < int(ihl)*4 {
|
||||
return
|
||||
}
|
||||
if totalLen > len(data) {
|
||||
totalLen = len(data)
|
||||
}
|
||||
return L3Info{
|
||||
|
||||
+13
-28
@@ -1,7 +1,6 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
@@ -9,8 +8,6 @@ import (
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
type fixedRuleset struct {
|
||||
@@ -60,25 +57,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
|
||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
||||
}
|
||||
|
||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
||||
udp := &layers.UDP{
|
||||
SrcPort: 12345,
|
||||
DstPort: 53,
|
||||
BaseLayer: layers.BaseLayer{
|
||||
Payload: []byte("query"),
|
||||
},
|
||||
}
|
||||
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
|
||||
payload := []byte("query")
|
||||
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx)
|
||||
s := f.New(tuple, payload, ctx)
|
||||
|
||||
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
|
||||
t.Fatalf("update ruleset: %v", err)
|
||||
}
|
||||
|
||||
if !s.Accept(udp, false, ctx) {
|
||||
if !s.Accept(false, ctx) {
|
||||
t.Fatalf("unexpected Accept=false for virgin stream")
|
||||
}
|
||||
s.Feed(udp, false, ctx)
|
||||
s.Feed(false, payload, ctx)
|
||||
if ctx.Verdict != udpVerdictDropStream {
|
||||
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
|
||||
}
|
||||
@@ -96,21 +87,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
||||
}
|
||||
|
||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
||||
udp := &layers.UDP{
|
||||
SrcPort: 12345,
|
||||
DstPort: 53,
|
||||
BaseLayer: layers.BaseLayer{
|
||||
Payload: []byte("query"),
|
||||
},
|
||||
}
|
||||
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
|
||||
payload := []byte("query")
|
||||
|
||||
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1)
|
||||
if !s.Accept(udp, false, ctx1) {
|
||||
s := f.New(tuple, payload, ctx1)
|
||||
if !s.Accept(false, ctx1) {
|
||||
t.Fatalf("unexpected Accept=false before first feed")
|
||||
}
|
||||
s.Feed(udp, false, ctx1)
|
||||
s.Feed(false, payload, ctx1)
|
||||
if ctx1.Verdict != udpVerdictAcceptStream {
|
||||
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
|
||||
}
|
||||
@@ -120,16 +105,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||
if !s.Accept(udp, false, ctx2) {
|
||||
if !s.Accept(false, ctx2) {
|
||||
t.Fatalf("expected Accept=true after ruleset update")
|
||||
}
|
||||
s.Feed(udp, false, ctx2)
|
||||
s.Feed(false, payload, ctx2)
|
||||
if ctx2.Verdict != udpVerdictDropStream {
|
||||
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
|
||||
}
|
||||
|
||||
ctx3 := &udpContext{Verdict: udpVerdictAccept}
|
||||
if s.Accept(udp, false, ctx3) {
|
||||
if s.Accept(false, ctx3) {
|
||||
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
|
||||
}
|
||||
if ctx3.Verdict != udpVerdictDropStream {
|
||||
|
||||
+24
-3
@@ -3,6 +3,7 @@ package engine
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
|
||||
const tcpFlowMaxBuffer = 16384
|
||||
|
||||
const tcpFlowIdleTimeout = 10 * time.Minute
|
||||
|
||||
type tcpFlowDirection uint8
|
||||
|
||||
const (
|
||||
@@ -37,6 +40,7 @@ type tcpFlow struct {
|
||||
doneEntries []*tcpFlowEntry
|
||||
lastVerdict io.Verdict
|
||||
feedCalled [2]bool
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type tcpFlowEntry struct {
|
||||
@@ -67,16 +71,17 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
||||
expected := f.dirSeq[dir]
|
||||
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
||||
f.feedCalled[dir] = true
|
||||
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
||||
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
||||
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
|
||||
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
||||
propUpdated = f.feedAnalyzers(rev)
|
||||
}
|
||||
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
f.runMatch(rs, version, rulesetChanged, propUpdated)
|
||||
f.maybeFinalizeVerdict()
|
||||
f.lastSeen = time.Now()
|
||||
return f.lastVerdict
|
||||
}
|
||||
|
||||
@@ -218,7 +223,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
||||
Props: make(analyzer.CombinedPropMap),
|
||||
}
|
||||
m.logger.TCPStreamNew(m.workerID, info)
|
||||
rs, version := m.rulesetSource()
|
||||
var rs ruleset.Ruleset
|
||||
var version uint64
|
||||
if m.rulesetSource != nil {
|
||||
rs, version = m.rulesetSource()
|
||||
}
|
||||
var ans []analyzer.TCPAnalyzer
|
||||
if rs != nil {
|
||||
baseAns := rs.Analyzers(info)
|
||||
@@ -255,6 +264,7 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
||||
rulesetVersion: version,
|
||||
activeEntries: entries,
|
||||
lastVerdict: io.VerdictAccept,
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
||||
return flow
|
||||
@@ -266,6 +276,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *tcpFlowManager) cleanupIdle(now time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for id, flow := range m.flows {
|
||||
if now.Sub(flow.lastSeen) > tcpFlowIdleTimeout {
|
||||
flow.closeActiveEntries()
|
||||
delete(m.flows, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
||||
if !entry.HasLimit {
|
||||
update, done = entry.Stream.Feed(rev, true, false, 0, data)
|
||||
|
||||
+50
-54
@@ -12,8 +12,6 @@ import (
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
)
|
||||
|
||||
@@ -49,9 +47,10 @@ type udpStreamFactory struct {
|
||||
RulesetVersion uint64
|
||||
}
|
||||
|
||||
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
|
||||
func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
|
||||
id := f.Node.Generate()
|
||||
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
|
||||
ipSrc := net.IP(k.AIP[:k.ALen])
|
||||
ipDst := net.IP(k.BIP[:k.BLen])
|
||||
info := ruleset.StreamInfo{
|
||||
ID: id.Int64(),
|
||||
Protocol: ruleset.ProtocolUDP,
|
||||
@@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
||||
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(udp.SrcPort),
|
||||
DstPort: uint16(udp.DstPort),
|
||||
SrcPort: k.APort,
|
||||
DstPort: k.BPort,
|
||||
Props: make(analyzer.CombinedPropMap),
|
||||
}
|
||||
f.Logger.UDPStreamNew(f.WorkerID, info)
|
||||
@@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
||||
if rs != nil {
|
||||
baseAns := rs.Analyzers(info)
|
||||
if f.Selector != nil {
|
||||
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
|
||||
baseAns = f.Selector.SelectUDP(baseAns, payload)
|
||||
}
|
||||
ans = analyzersToUDPAnalyzers(baseAns)
|
||||
}
|
||||
// Create entries for each analyzer
|
||||
entries := make([]*udpStreamEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
entries = append(entries, &udpStreamEntry{
|
||||
@@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
||||
Stream: a.NewUDP(analyzer.UDPInfo{
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(udp.SrcPort),
|
||||
DstPort: uint16(udp.DstPort),
|
||||
SrcPort: k.APort,
|
||||
DstPort: k.BPort,
|
||||
}, &analyzerLogger{
|
||||
StreamID: id.Int64(),
|
||||
Name: a.Name(),
|
||||
@@ -125,9 +123,14 @@ type udpStreamManager struct {
|
||||
}
|
||||
|
||||
type udpStreamValue struct {
|
||||
Stream *udpStream
|
||||
IPFlow gopacket.Flow
|
||||
UDPFlow gopacket.Flow
|
||||
Stream *udpStream
|
||||
Tuple udpTupleKey
|
||||
}
|
||||
|
||||
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
|
||||
fwd := v.Tuple == k
|
||||
rev = v.Tuple == reverseTuple(k)
|
||||
return fwd || rev, rev
|
||||
}
|
||||
|
||||
type udpTupleKey struct {
|
||||
@@ -139,12 +142,6 @@ type udpTupleKey struct {
|
||||
BPort uint16
|
||||
}
|
||||
|
||||
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
||||
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
||||
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
||||
return fwd || rev, rev
|
||||
}
|
||||
|
||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||
m := &udpStreamManager{
|
||||
factory: factory,
|
||||
@@ -153,6 +150,9 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
|
||||
stats: stats,
|
||||
}
|
||||
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
||||
if v != nil && v.Stream != nil {
|
||||
v.Stream.Close()
|
||||
}
|
||||
m.removeTupleMappingLocked(k)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -162,16 +162,12 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
||||
rev := false
|
||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
||||
value, ok := m.streams.Get(streamID)
|
||||
tuple := canonicalUDPTupleKey(ipFlow, udp)
|
||||
if !ok {
|
||||
if m.stats != nil {
|
||||
m.stats.UDPTupleLookups.Add(1)
|
||||
}
|
||||
// Conntrack IDs can change during early flow lifetime on some systems.
|
||||
// Rebind by canonical 5-tuple in O(1).
|
||||
matchedKey, found := m.tupleIndex[tuple]
|
||||
var matchedValue *udpStreamValue
|
||||
var matchedRev bool
|
||||
@@ -188,7 +184,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
||||
}
|
||||
}
|
||||
if found {
|
||||
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
|
||||
_, matchedRev = matchedValue.Match(tuple)
|
||||
value = matchedValue
|
||||
rev = matchedRev
|
||||
if matchedKey != streamID {
|
||||
@@ -197,32 +193,27 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
// New stream
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
||||
IPFlow: ipFlow,
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
// Stream ID exists, but is it really the same stream?
|
||||
ok, rev = value.Match(ipFlow, udp.TransportFlow())
|
||||
ok, rev = value.Match(tuple)
|
||||
if !ok {
|
||||
// It's not - close the old stream & replace it with a new one
|
||||
value.Stream.Close()
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
||||
IPFlow: ipFlow,
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
}
|
||||
if value.Stream.Accept(udp, rev, uc) {
|
||||
value.Stream.Feed(udp, rev, uc)
|
||||
if value.Stream.Accept(rev, uc) {
|
||||
value.Stream.Feed(rev, payload, uc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,25 +233,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
|
||||
}
|
||||
}
|
||||
|
||||
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
|
||||
srcIP := ipFlow.Src().Raw()
|
||||
dstIP := ipFlow.Dst().Raw()
|
||||
srcPort := uint16(udp.SrcPort)
|
||||
dstPort := uint16(udp.DstPort)
|
||||
func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
|
||||
srcRaw := []byte(srcIP)
|
||||
dstRaw := []byte(dstIP)
|
||||
|
||||
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
|
||||
srcIP, dstIP = dstIP, srcIP
|
||||
if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
|
||||
srcRaw, dstRaw = dstRaw, srcRaw
|
||||
srcPort, dstPort = dstPort, srcPort
|
||||
}
|
||||
|
||||
var key udpTupleKey
|
||||
key.ALen = uint8(copy(key.AIP[:], srcIP))
|
||||
key.BLen = uint8(copy(key.BIP[:], dstIP))
|
||||
key.ALen = uint8(copy(key.AIP[:], srcRaw))
|
||||
key.BLen = uint8(copy(key.BIP[:], dstRaw))
|
||||
key.APort = srcPort
|
||||
key.BPort = dstPort
|
||||
return key
|
||||
}
|
||||
|
||||
func reverseTuple(k udpTupleKey) udpTupleKey {
|
||||
var r udpTupleKey
|
||||
r.ALen = k.BLen
|
||||
r.BLen = k.ALen
|
||||
r.AIP = k.BIP
|
||||
r.BIP = k.AIP
|
||||
r.APort = k.BPort
|
||||
r.BPort = k.APort
|
||||
return r
|
||||
}
|
||||
|
||||
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
|
||||
if len(aIP) != len(bIP) {
|
||||
if len(aIP) < len(bIP) {
|
||||
@@ -298,11 +298,8 @@ type udpStreamEntry struct {
|
||||
Quota int
|
||||
}
|
||||
|
||||
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
|
||||
func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
|
||||
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
|
||||
// Make sure every stream matches against the ruleset at least once,
|
||||
// even if there are no activeEntries, as the ruleset may have built-in
|
||||
// properties that need to be matched.
|
||||
return true
|
||||
} else {
|
||||
uc.Verdict = s.lastVerdict
|
||||
@@ -310,12 +307,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
|
||||
func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
|
||||
updated := false
|
||||
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
||||
// Important: reverse order so we can remove entries
|
||||
entry := s.activeEntries[i]
|
||||
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
|
||||
update, closeUpdate, done := s.feedEntry(entry, rev, payload)
|
||||
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
||||
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
||||
updated = updated || up1 || up2
|
||||
@@ -345,7 +341,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
|
||||
action = ruleset.ActionMaybe
|
||||
} else {
|
||||
var err error
|
||||
uc.Packet, err = udpMI.Process(udp.Payload)
|
||||
uc.Packet, err = udpMI.Process(payload)
|
||||
if err != nil {
|
||||
// Modifier error, fallback to maybe
|
||||
s.logger.ModifyError(s.info, err)
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
type legacyUDPStreamValue struct {
|
||||
IPFlow gopacket.Flow
|
||||
UDPFlow gopacket.Flow
|
||||
Tuple udpTupleKey
|
||||
}
|
||||
|
||||
type emptyRuleset struct{}
|
||||
@@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
|
||||
}
|
||||
|
||||
const flowCount = 20000
|
||||
flows := make([]gopacket.Flow, flowCount)
|
||||
udps := make([]*layers.UDP, flowCount)
|
||||
tuples := make([]udpTupleKey, flowCount)
|
||||
payloads := make([][]byte, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
a := byte(i >> 8)
|
||||
c := byte(i)
|
||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||
udps[i] = &layers.UDP{
|
||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||
}
|
||||
var t udpTupleKey
|
||||
t.AIP = [16]byte{10, a, 0, c}
|
||||
t.ALen = 4
|
||||
t.BIP = [16]byte{172, 16, a, c}
|
||||
t.BLen = 4
|
||||
t.APort = 1024 + uint16(i%20000)
|
||||
t.BPort = 20000 + uint16((i*7)%20000)
|
||||
tuples[i] = t
|
||||
payloads[i] = []byte{0x01, 0x00, 0x00, 0x00}
|
||||
}
|
||||
|
||||
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||
@@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
|
||||
}
|
||||
ctx.Verdict = udpVerdictAccept
|
||||
ctx.Packet = nil
|
||||
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
|
||||
mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
|
||||
|
||||
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
||||
const flowCount = 5000
|
||||
flows := make([]gopacket.Flow, flowCount)
|
||||
udps := make([]*layers.UDP, flowCount)
|
||||
tuples := make([]udpTupleKey, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
a := byte(i >> 8)
|
||||
c := byte(i)
|
||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
||||
udps[i] = &layers.UDP{
|
||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
||||
}
|
||||
var t udpTupleKey
|
||||
t.AIP = [16]byte{10, a, 0, c}
|
||||
t.ALen = 4
|
||||
t.BIP = [16]byte{172, 16, a, c}
|
||||
t.BLen = 4
|
||||
t.APort = 1024 + uint16(i%20000)
|
||||
t.BPort = 20000 + uint16((i*7)%20000)
|
||||
tuples[i] = t
|
||||
}
|
||||
|
||||
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
|
||||
keys := make([]uint32, 0, flowCount)
|
||||
for i := 0; i < flowCount; i++ {
|
||||
streamID := uint32(i + 1)
|
||||
streams[streamID] = &legacyUDPStreamValue{
|
||||
IPFlow: flows[i],
|
||||
UDPFlow: udps[i].TransportFlow(),
|
||||
}
|
||||
streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]}
|
||||
keys = append(keys, streamID)
|
||||
}
|
||||
|
||||
@@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
||||
if _, ok := streams[streamID]; ok {
|
||||
continue
|
||||
}
|
||||
ipFlow := flows[idx]
|
||||
udpFlow := udps[idx].TransportFlow()
|
||||
tuple := tuples[idx]
|
||||
revTuple := reverseTuple(tuple)
|
||||
for _, k := range keys {
|
||||
v, ok := streams[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
|
||||
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
|
||||
if v.Tuple == tuple || v.Tuple == revTuple {
|
||||
delete(streams, k)
|
||||
streams[streamID] = v
|
||||
break
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -9,8 +8,6 @@ import (
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
type countingRuleset struct {
|
||||
@@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
|
||||
t.Fatalf("new manager: %v", err)
|
||||
}
|
||||
|
||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
||||
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
|
||||
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443}
|
||||
payload := []byte{0x01, 0x00, 0x00, 0x00}
|
||||
|
||||
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||
mgr.MatchWithContext(100, ipFlow, udp, ctx1)
|
||||
mgr.MatchWithContext(100, tuple, false, payload, ctx1)
|
||||
if got := newCalls.Load(); got != 1 {
|
||||
t.Fatalf("new stream calls=%d want=1", got)
|
||||
}
|
||||
|
||||
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||
mgr.MatchWithContext(200, ipFlow, udp, ctx2)
|
||||
mgr.MatchWithContext(200, tuple, false, payload, ctx2)
|
||||
if got := newCalls.Load(); got != 1 {
|
||||
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
|
||||
}
|
||||
|
||||
+11
-16
@@ -3,6 +3,7 @@ package engine
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
@@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) {
|
||||
func (w *worker) Run(ctx context.Context) {
|
||||
w.logger.WorkerStart(w.id)
|
||||
defer w.logger.WorkerStop(w.id)
|
||||
|
||||
tcpSweepTicker := time.NewTicker(1 * time.Minute)
|
||||
defer tcpSweepTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-tcpSweepTicker.C:
|
||||
w.tcpFlowMgr.cleanupIdle(time.Now())
|
||||
case wp := <-w.packetChan:
|
||||
if wp == nil {
|
||||
return
|
||||
@@ -202,15 +209,6 @@ func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []by
|
||||
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
||||
ipSrc := l3.SrcIPAddr()
|
||||
ipDst := l3.DstIPAddr()
|
||||
endpointType := layers.EndpointIPv4
|
||||
flowSrc := ipSrc.To4()
|
||||
flowDst := ipDst.To4()
|
||||
if l3.Version == 6 {
|
||||
endpointType = layers.EndpointIPv6
|
||||
flowSrc = ipSrc.To16()
|
||||
flowDst = ipDst.To16()
|
||||
}
|
||||
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
|
||||
|
||||
if len(srcMAC) == 0 && w.macResolver != nil {
|
||||
srcMAC = w.macResolver.Resolve(ipSrc)
|
||||
@@ -221,12 +219,9 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
|
||||
SrcMAC: srcMAC,
|
||||
DstMAC: dstMAC,
|
||||
}
|
||||
// Temporarily set payload on a UDP layer so existing UDP handling works.
|
||||
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
||||
BaseLayer: layers.BaseLayer{Payload: payload},
|
||||
SrcPort: layers.UDPPort(udp.SrcPort),
|
||||
DstPort: layers.UDPPort(udp.DstPort),
|
||||
}, uc)
|
||||
|
||||
tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort)
|
||||
w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc)
|
||||
return io.Verdict(uc.Verdict), uc.Packet
|
||||
}
|
||||
|
||||
@@ -253,7 +248,7 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []b
|
||||
if err != nil {
|
||||
return io.VerdictAccept, nil
|
||||
}
|
||||
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
|
||||
return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...)
|
||||
}
|
||||
|
||||
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
|
||||
|
||||
Reference in New Issue
Block a user