engine: more performance improvements
This commit is contained in:
+6
-1
@@ -27,6 +27,7 @@ const (
|
||||
type engine struct {
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
macResolver *sourceMACResolver
|
||||
workers []*worker
|
||||
stats *statsCounters
|
||||
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
||||
@@ -46,7 +47,7 @@ func NewEngine(config Config) (Engine, error) {
|
||||
}
|
||||
overflowPolicy := config.OverflowPolicy
|
||||
if overflowPolicy == "" {
|
||||
overflowPolicy = OverflowPolicyDrop
|
||||
overflowPolicy = OverflowPolicyAccept
|
||||
}
|
||||
selectionMode := config.AnalyzerSelectionMode
|
||||
if selectionMode == "" {
|
||||
@@ -80,6 +81,7 @@ func NewEngine(config Config) (Engine, error) {
|
||||
e := &engine{
|
||||
logger: config.Logger,
|
||||
io: config.IO,
|
||||
macResolver: macResolver,
|
||||
workers: workers,
|
||||
stats: stats,
|
||||
overflowPolicy: overflowPolicy,
|
||||
@@ -105,6 +107,9 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
if e.macResolver != nil {
|
||||
go e.macResolver.Run(ioCtx)
|
||||
}
|
||||
go e.drainResults(ioCtx)
|
||||
go e.sweepVerdicts(ioCtx)
|
||||
|
||||
|
||||
+30
-41
@@ -5,6 +5,7 @@ package engine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -52,38 +53,6 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
r.mu.RLock()
|
||||
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
|
||||
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
|
||||
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
|
||||
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
|
||||
out := append(net.HardwareAddr(nil), mac...)
|
||||
r.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
if mac := r.arpByIP[ipKey]; len(mac) != 0 && !arpRefreshDue {
|
||||
out := append(net.HardwareAddr(nil), mac...)
|
||||
r.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
if mac := r.ndpByIP[ipKey]; len(mac) != 0 && !ndpRefreshDue {
|
||||
out := append(net.HardwareAddr(nil), mac...)
|
||||
r.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
if ifaceRefreshDue {
|
||||
r.refreshIfaceCache(now)
|
||||
}
|
||||
if arpRefreshDue {
|
||||
r.refreshARPCache(now)
|
||||
}
|
||||
if ndpRefreshDue {
|
||||
r.refreshNDPCache(now)
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
|
||||
@@ -95,18 +64,38 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
||||
if mac := r.ndpByIP[ipKey]; len(mac) != 0 {
|
||||
return append(net.HardwareAddr(nil), mac...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// On-demand IPv6 neighbor lookup via route-netlink as a last fast path.
|
||||
if ip.To4() == nil {
|
||||
if mac, ok := lookupNeighborMACNetlink(ip); ok {
|
||||
out := append(net.HardwareAddr(nil), mac...)
|
||||
r.mu.Lock()
|
||||
r.ndpByIP[ipKey] = append(net.HardwareAddr(nil), mac...)
|
||||
r.mu.Unlock()
|
||||
return out
|
||||
func (r *sourceMACResolver) Run(ctx context.Context) {
|
||||
r.refreshAll(time.Now())
|
||||
ticker := time.NewTicker(arpCacheTTL)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
r.refreshAll(now)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sourceMACResolver) refreshAll(now time.Time) {
|
||||
r.mu.RLock()
|
||||
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
|
||||
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
|
||||
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
|
||||
r.mu.RUnlock()
|
||||
if ifaceRefreshDue {
|
||||
r.refreshIfaceCache(now)
|
||||
}
|
||||
if arpRefreshDue {
|
||||
r.refreshARPCache(now)
|
||||
}
|
||||
if ndpRefreshDue {
|
||||
r.refreshNDPCache(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *sourceMACResolver) refreshIfaceCache(now time.Time) {
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
package engine
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type sourceMACResolver struct{}
|
||||
|
||||
@@ -15,3 +18,7 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
||||
_ = ip
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sourceMACResolver) Run(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
)
|
||||
|
||||
type recordingPacket struct {
|
||||
streamID uint32
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (p recordingPacket) StreamID() uint32 { return p.streamID }
|
||||
func (p recordingPacket) Data() []byte { return p.data }
|
||||
|
||||
type recordingPacketIO struct {
|
||||
verdict io.Verdict
|
||||
}
|
||||
|
||||
func (r *recordingPacketIO) Register(context.Context, io.PacketCallback) error { return nil }
|
||||
func (r *recordingPacketIO) SetVerdict(_ io.Packet, v io.Verdict, _ []byte) error {
|
||||
r.verdict = v
|
||||
return nil
|
||||
}
|
||||
func (r *recordingPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *recordingPacketIO) Close() error { return nil }
|
||||
|
||||
func TestEngineDefaultOverflowPolicyAccepts(t *testing.T) {
|
||||
packetIO := &recordingPacketIO{}
|
||||
eng, err := NewEngine(Config{
|
||||
Logger: noopTestLogger{},
|
||||
IO: packetIO,
|
||||
Workers: 1,
|
||||
WorkerQueueSize: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewEngine error: %v", err)
|
||||
}
|
||||
e := eng.(*engine)
|
||||
if e.overflowPolicy != OverflowPolicyAccept {
|
||||
t.Fatalf("overflow policy=%v want=%v", e.overflowPolicy, OverflowPolicyAccept)
|
||||
}
|
||||
|
||||
e.workers[0].packetChan <- &workerPacket{}
|
||||
packet := recordingPacket{
|
||||
streamID: 1,
|
||||
data: serializeIPv6TCP(
|
||||
t,
|
||||
net.ParseIP("2001:db8::11").To16(),
|
||||
net.ParseIP("2001:db8::22").To16(),
|
||||
42310,
|
||||
443,
|
||||
1000,
|
||||
),
|
||||
}
|
||||
|
||||
e.dispatch(packet)
|
||||
stats := e.Stats()
|
||||
if packetIO.verdict != io.VerdictAccept {
|
||||
t.Fatalf("overflow verdict=%v want=%v", packetIO.verdict, io.VerdictAccept)
|
||||
}
|
||||
if stats.OverflowEvents != 1 || stats.OverflowAccepts != 1 || stats.OverflowDrops != 0 {
|
||||
t.Fatalf("overflow stats=%+v", stats)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,89 @@ func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||
return ruleset.MatchResult{Action: r.action}
|
||||
}
|
||||
|
||||
type analyzerRuleset struct {
|
||||
action ruleset.Action
|
||||
ans []analyzer.Analyzer
|
||||
}
|
||||
|
||||
func (r analyzerRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
|
||||
return r.ans
|
||||
}
|
||||
|
||||
func (r analyzerRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||
return ruleset.MatchResult{Action: r.action}
|
||||
}
|
||||
|
||||
type countingTCPAnalyzer struct {
|
||||
newCalls *int
|
||||
feedCalls *int
|
||||
}
|
||||
|
||||
func (a countingTCPAnalyzer) Name() string { return "tls" }
|
||||
func (a countingTCPAnalyzer) Limit() int { return 0 }
|
||||
func (a countingTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
|
||||
(*a.newCalls)++
|
||||
return countingTCPStream{feedCalls: a.feedCalls}
|
||||
}
|
||||
|
||||
type countingTCPStream struct {
|
||||
feedCalls *int
|
||||
}
|
||||
|
||||
func (s countingTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
|
||||
(*s.feedCalls)++
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s countingTCPStream) Close(bool) *analyzer.PropUpdate {
|
||||
return nil
|
||||
}
|
||||
|
||||
type logFinalizingRuleset struct {
|
||||
ans []analyzer.Analyzer
|
||||
}
|
||||
|
||||
func (r logFinalizingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
|
||||
return r.ans
|
||||
}
|
||||
|
||||
func (r logFinalizingRuleset) Match(info ruleset.StreamInfo) ruleset.MatchResult {
|
||||
if _, ok := info.Props["tls"]; ok {
|
||||
return ruleset.MatchResult{Action: ruleset.ActionMaybe, Logged: true}
|
||||
}
|
||||
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||
}
|
||||
|
||||
func (r logFinalizingRuleset) CanFinalizeAfterLog(ruleset.StreamInfo, []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type requestPropTCPAnalyzer struct {
|
||||
closeCalls *int
|
||||
}
|
||||
|
||||
func (a requestPropTCPAnalyzer) Name() string { return "tls" }
|
||||
func (a requestPropTCPAnalyzer) Limit() int { return 0 }
|
||||
func (a requestPropTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
|
||||
return requestPropTCPStream{closeCalls: a.closeCalls}
|
||||
}
|
||||
|
||||
type requestPropTCPStream struct {
|
||||
closeCalls *int
|
||||
}
|
||||
|
||||
func (s requestPropTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
|
||||
return &analyzer.PropUpdate{
|
||||
Type: analyzer.PropUpdateMerge,
|
||||
M: analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||
}, false
|
||||
}
|
||||
|
||||
func (s requestPropTCPStream) Close(bool) *analyzer.PropUpdate {
|
||||
(*s.closeCalls)++
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopTestLogger struct{}
|
||||
|
||||
func (noopTestLogger) WorkerStart(int) {}
|
||||
@@ -207,3 +290,90 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
||||
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPFlowDelaysAnalyzerCreationUntilPayload(t *testing.T) {
|
||||
node, err := snowflake.NewNode(0)
|
||||
if err != nil {
|
||||
t.Fatalf("create node: %v", err)
|
||||
}
|
||||
newCalls := 0
|
||||
feedCalls := 0
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
|
||||
mgr.updateRuleset(analyzerRuleset{
|
||||
action: ruleset.ActionMaybe,
|
||||
ans: []analyzer.Analyzer{countingTCPAnalyzer{
|
||||
newCalls: &newCalls,
|
||||
feedCalls: &feedCalls,
|
||||
}},
|
||||
}, 0)
|
||||
|
||||
l3 := L3Info{
|
||||
Version: 4,
|
||||
Protocol: 6,
|
||||
SrcIP: [4]byte{10, 0, 0, 1},
|
||||
DstIP: [4]byte{10, 0, 0, 2},
|
||||
}
|
||||
tcp := TCPInfo{
|
||||
SrcPort: 12345,
|
||||
DstPort: 443,
|
||||
Seq: 100,
|
||||
}
|
||||
|
||||
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||
if v != io.VerdictAccept {
|
||||
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
|
||||
}
|
||||
if newCalls != 0 || feedCalls != 0 {
|
||||
t.Fatalf("empty packet created/feed analyzer: new=%d feed=%d", newCalls, feedCalls)
|
||||
}
|
||||
|
||||
tcp.Seq = 101
|
||||
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
|
||||
if v != io.VerdictAccept {
|
||||
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAccept)
|
||||
}
|
||||
if newCalls != 1 || feedCalls != 1 {
|
||||
t.Fatalf("payload should create/feed analyzer once: new=%d feed=%d", newCalls, feedCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPFlowFinalizesAfterLogClassification(t *testing.T) {
|
||||
node, err := snowflake.NewNode(0)
|
||||
if err != nil {
|
||||
t.Fatalf("create node: %v", err)
|
||||
}
|
||||
closeCalls := 0
|
||||
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
|
||||
mgr.updateRuleset(logFinalizingRuleset{
|
||||
ans: []analyzer.Analyzer{requestPropTCPAnalyzer{closeCalls: &closeCalls}},
|
||||
}, 0)
|
||||
|
||||
l3 := L3Info{
|
||||
Version: 4,
|
||||
Protocol: 6,
|
||||
SrcIP: [4]byte{10, 0, 0, 1},
|
||||
DstIP: [4]byte{10, 0, 0, 2},
|
||||
}
|
||||
tcp := TCPInfo{
|
||||
SrcPort: 12345,
|
||||
DstPort: 443,
|
||||
Seq: 100,
|
||||
}
|
||||
|
||||
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||
if v != io.VerdictAccept {
|
||||
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
|
||||
}
|
||||
|
||||
tcp.Seq = 101
|
||||
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
|
||||
if v != io.VerdictAcceptStream {
|
||||
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAcceptStream)
|
||||
}
|
||||
if closeCalls != 1 {
|
||||
t.Fatalf("expected analyzer to be closed once after finalization, got %d", closeCalls)
|
||||
}
|
||||
if _, ok := mgr.flows[1]; ok {
|
||||
t.Fatal("expected finalized TCP flow to be removed from manager")
|
||||
}
|
||||
}
|
||||
|
||||
+75
-4
@@ -27,6 +27,8 @@ type tcpFlow struct {
|
||||
streamID uint32
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
|
||||
dirSeq [2]uint32
|
||||
dirBuf [2][]byte
|
||||
@@ -41,6 +43,9 @@ type tcpFlow struct {
|
||||
lastVerdict io.Verdict
|
||||
feedCalled [2]bool
|
||||
lastSeen time.Time
|
||||
|
||||
pendingAnalyzers []analyzer.Analyzer
|
||||
selector *analyzerSelector
|
||||
}
|
||||
|
||||
type tcpFlowEntry struct {
|
||||
@@ -54,7 +59,7 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
||||
rs, version := f.currentRuleset()
|
||||
rulesetChanged := version != f.rulesetVersion
|
||||
|
||||
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
|
||||
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() {
|
||||
return f.lastVerdict
|
||||
}
|
||||
|
||||
@@ -68,6 +73,9 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
||||
propUpdated := false
|
||||
if len(payload) > 0 {
|
||||
dir, rev := f.resolveDirection(tcp)
|
||||
if len(f.pendingAnalyzers) > 0 {
|
||||
f.initPendingAnalyzers(payload)
|
||||
}
|
||||
expected := f.dirSeq[dir]
|
||||
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
||||
f.feedCalled[dir] = true
|
||||
@@ -108,6 +116,52 @@ func (f *tcpFlow) feedAnalyzers(rev bool) bool {
|
||||
return updated
|
||||
}
|
||||
|
||||
func (f *tcpFlow) initPendingAnalyzers(payload []byte) {
|
||||
baseAns := f.pendingAnalyzers
|
||||
f.pendingAnalyzers = nil
|
||||
if f.selector != nil {
|
||||
baseAns = f.selector.SelectTCP(baseAns, payload)
|
||||
}
|
||||
ans := analyzersToTCPAnalyzers(baseAns)
|
||||
if len(ans) == 0 {
|
||||
return
|
||||
}
|
||||
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
entries = append(entries, &tcpFlowEntry{
|
||||
Name: a.Name(),
|
||||
Stream: a.NewTCP(analyzer.TCPInfo{
|
||||
SrcIP: f.srcIP,
|
||||
DstIP: f.dstIP,
|
||||
SrcPort: f.srcPort,
|
||||
DstPort: f.dstPort,
|
||||
}, &analyzerLogger{
|
||||
StreamID: f.info.ID,
|
||||
Name: a.Name(),
|
||||
Logger: f.logger,
|
||||
}),
|
||||
HasLimit: a.Limit() > 0,
|
||||
Quota: a.Limit(),
|
||||
})
|
||||
}
|
||||
f.activeEntries = append(f.activeEntries, entries...)
|
||||
}
|
||||
|
||||
func (f *tcpFlow) hasPendingAnalyzers() bool {
|
||||
return len(f.pendingAnalyzers) > 0
|
||||
}
|
||||
|
||||
func (f *tcpFlow) analyzerNames() []string {
|
||||
names := make([]string, 0, len(f.activeEntries)+len(f.pendingAnalyzers))
|
||||
for _, entry := range f.activeEntries {
|
||||
names = append(names, entry.Name)
|
||||
}
|
||||
for _, a := range f.pendingAnalyzers {
|
||||
names = append(names, a.Name())
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) {
|
||||
if !propUpdated && !f.virgin && !rulesetChanged {
|
||||
return
|
||||
@@ -125,11 +179,15 @@ func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bo
|
||||
f.lastVerdict = verdict
|
||||
f.closeActiveEntries()
|
||||
f.logger.TCPStreamAction(f.info, action, false)
|
||||
} else if result.Logged && canFinalizeAfterLog(rs, f.info, f.analyzerNames()) {
|
||||
f.lastVerdict = io.VerdictAcceptStream
|
||||
f.closeActiveEntries()
|
||||
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *tcpFlow) maybeFinalizeVerdict() {
|
||||
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept {
|
||||
if len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() && f.lastVerdict == io.VerdictAccept {
|
||||
f.lastVerdict = io.VerdictAcceptStream
|
||||
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
||||
}
|
||||
@@ -231,8 +289,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
||||
var ans []analyzer.TCPAnalyzer
|
||||
if rs != nil {
|
||||
baseAns := rs.Analyzers(info)
|
||||
baseAns = m.selector.SelectTCP(baseAns, payload)
|
||||
ans = analyzersToTCPAnalyzers(baseAns)
|
||||
if len(payload) > 0 {
|
||||
baseAns = m.selector.SelectTCP(baseAns, payload)
|
||||
ans = analyzersToTCPAnalyzers(baseAns)
|
||||
}
|
||||
}
|
||||
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
@@ -257,6 +317,8 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
||||
streamID: streamID,
|
||||
srcPort: tcp.SrcPort,
|
||||
dstPort: tcp.DstPort,
|
||||
srcIP: ipSrc,
|
||||
dstIP: ipDst,
|
||||
info: info,
|
||||
virgin: true,
|
||||
logger: m.logger,
|
||||
@@ -265,6 +327,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
||||
activeEntries: entries,
|
||||
lastVerdict: io.VerdictAccept,
|
||||
lastSeen: time.Now(),
|
||||
selector: m.selector,
|
||||
}
|
||||
if len(payload) == 0 && rs != nil {
|
||||
flow.pendingAnalyzers = rs.Analyzers(info)
|
||||
}
|
||||
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
||||
return flow
|
||||
@@ -325,3 +391,8 @@ func actionToTCPVerdict(a ruleset.Action) io.Verdict {
|
||||
return io.VerdictAcceptStream
|
||||
}
|
||||
}
|
||||
|
||||
func canFinalizeAfterLog(rs ruleset.Ruleset, info ruleset.StreamInfo, activeAnalyzers []string) bool {
|
||||
finalizer, ok := rs.(ruleset.LogFinalizer)
|
||||
return ok && finalizer.CanFinalizeAfterLog(info, activeAnalyzers)
|
||||
}
|
||||
|
||||
+75
-24
@@ -2,6 +2,7 @@ package engine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
)
|
||||
|
||||
// udpVerdict is a subset of io.Verdict for UDP streams.
|
||||
@@ -116,15 +116,18 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
||||
|
||||
type udpStreamManager struct {
|
||||
factory *udpStreamFactory
|
||||
streams *lru.Cache[uint32, *udpStreamValue]
|
||||
streams map[uint32]*list.Element
|
||||
order *list.List
|
||||
maxStreams int
|
||||
tupleIndex map[udpTupleKey]uint32
|
||||
streamTuples map[uint32]udpTupleKey
|
||||
stats *statsCounters
|
||||
}
|
||||
|
||||
type udpStreamValue struct {
|
||||
Stream *udpStream
|
||||
Tuple udpTupleKey
|
||||
StreamID uint32
|
||||
Stream *udpStream
|
||||
Tuple udpTupleKey
|
||||
}
|
||||
|
||||
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
|
||||
@@ -143,27 +146,23 @@ type udpTupleKey struct {
|
||||
}
|
||||
|
||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||
if maxStreams <= 0 {
|
||||
maxStreams = 1
|
||||
}
|
||||
m := &udpStreamManager{
|
||||
factory: factory,
|
||||
streams: make(map[uint32]*list.Element, maxStreams),
|
||||
order: list.New(),
|
||||
maxStreams: maxStreams,
|
||||
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
||||
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
||||
stats: stats,
|
||||
}
|
||||
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
||||
if v != nil && v.Stream != nil {
|
||||
v.Stream.Close()
|
||||
}
|
||||
m.removeTupleMappingLocked(k)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.streams = ss
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
||||
value, ok := m.streams.Get(streamID)
|
||||
value, ok := m.get(streamID)
|
||||
if !ok {
|
||||
if m.stats != nil {
|
||||
m.stats.UDPTupleLookups.Add(1)
|
||||
@@ -176,7 +175,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
||||
m.stats.UDPTupleHits.Add(1)
|
||||
}
|
||||
var hasValue bool
|
||||
matchedValue, hasValue = m.streams.Get(matchedKey)
|
||||
matchedValue, hasValue = m.get(matchedKey)
|
||||
if !hasValue || matchedValue == nil {
|
||||
delete(m.tupleIndex, tuple)
|
||||
delete(m.streamTuples, matchedKey)
|
||||
@@ -188,16 +187,18 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
||||
value = matchedValue
|
||||
rev = matchedRev
|
||||
if matchedKey != streamID {
|
||||
m.streams.Remove(matchedKey)
|
||||
m.streams.Add(streamID, matchedValue)
|
||||
m.remove(matchedKey, false)
|
||||
matchedValue.StreamID = streamID
|
||||
m.add(streamID, matchedValue)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
StreamID: streamID,
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
} else {
|
||||
@@ -205,10 +206,11 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
||||
if !ok {
|
||||
value.Stream.Close()
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
StreamID: streamID,
|
||||
Stream: m.factory.New(tuple, payload, uc),
|
||||
Tuple: tuple,
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
m.add(streamID, value)
|
||||
m.bindTupleLocked(streamID, tuple)
|
||||
}
|
||||
}
|
||||
@@ -217,6 +219,55 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) get(streamID uint32) (*udpStreamValue, bool) {
|
||||
ele, ok := m.streams[streamID]
|
||||
if !ok || ele == nil {
|
||||
return nil, false
|
||||
}
|
||||
m.order.MoveToFront(ele)
|
||||
value, ok := ele.Value.(*udpStreamValue)
|
||||
return value, ok && value != nil
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) add(streamID uint32, value *udpStreamValue) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
if existing, ok := m.streams[streamID]; ok {
|
||||
existing.Value = value
|
||||
m.order.MoveToFront(existing)
|
||||
return
|
||||
}
|
||||
value.StreamID = streamID
|
||||
m.streams[streamID] = m.order.PushFront(value)
|
||||
for len(m.streams) > m.maxStreams {
|
||||
back := m.order.Back()
|
||||
if back == nil {
|
||||
return
|
||||
}
|
||||
evicted, _ := back.Value.(*udpStreamValue)
|
||||
if evicted == nil {
|
||||
m.order.Remove(back)
|
||||
continue
|
||||
}
|
||||
m.remove(evicted.StreamID, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) remove(streamID uint32, closeStream bool) {
|
||||
ele, ok := m.streams[streamID]
|
||||
if !ok || ele == nil {
|
||||
return
|
||||
}
|
||||
value, _ := ele.Value.(*udpStreamValue)
|
||||
delete(m.streams, streamID)
|
||||
m.order.Remove(ele)
|
||||
m.removeTupleMappingLocked(streamID)
|
||||
if closeStream && value != nil && value.Stream != nil {
|
||||
value.Stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
||||
m.removeTupleMappingLocked(streamID)
|
||||
m.tupleIndex[key] = streamID
|
||||
|
||||
Reference in New Issue
Block a user