engine: more performance improvements
This commit is contained in:
+6
-1
@@ -27,6 +27,7 @@ const (
|
|||||||
type engine struct {
|
type engine struct {
|
||||||
logger Logger
|
logger Logger
|
||||||
io io.PacketIO
|
io io.PacketIO
|
||||||
|
macResolver *sourceMACResolver
|
||||||
workers []*worker
|
workers []*worker
|
||||||
stats *statsCounters
|
stats *statsCounters
|
||||||
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
verdicts sync.Map // streamID(uint32) -> verdictEntry
|
||||||
@@ -46,7 +47,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
}
|
}
|
||||||
overflowPolicy := config.OverflowPolicy
|
overflowPolicy := config.OverflowPolicy
|
||||||
if overflowPolicy == "" {
|
if overflowPolicy == "" {
|
||||||
overflowPolicy = OverflowPolicyDrop
|
overflowPolicy = OverflowPolicyAccept
|
||||||
}
|
}
|
||||||
selectionMode := config.AnalyzerSelectionMode
|
selectionMode := config.AnalyzerSelectionMode
|
||||||
if selectionMode == "" {
|
if selectionMode == "" {
|
||||||
@@ -80,6 +81,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
e := &engine{
|
e := &engine{
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
io: config.IO,
|
io: config.IO,
|
||||||
|
macResolver: macResolver,
|
||||||
workers: workers,
|
workers: workers,
|
||||||
stats: stats,
|
stats: stats,
|
||||||
overflowPolicy: overflowPolicy,
|
overflowPolicy: overflowPolicy,
|
||||||
@@ -105,6 +107,9 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
|
if e.macResolver != nil {
|
||||||
|
go e.macResolver.Run(ioCtx)
|
||||||
|
}
|
||||||
go e.drainResults(ioCtx)
|
go e.drainResults(ioCtx)
|
||||||
go e.sweepVerdicts(ioCtx)
|
go e.sweepVerdicts(ioCtx)
|
||||||
|
|
||||||
|
|||||||
+32
-43
@@ -5,6 +5,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -52,38 +53,6 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
r.mu.RLock()
|
|
||||||
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
|
|
||||||
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
|
|
||||||
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
|
|
||||||
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
|
|
||||||
out := append(net.HardwareAddr(nil), mac...)
|
|
||||||
r.mu.RUnlock()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
if mac := r.arpByIP[ipKey]; len(mac) != 0 && !arpRefreshDue {
|
|
||||||
out := append(net.HardwareAddr(nil), mac...)
|
|
||||||
r.mu.RUnlock()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
if mac := r.ndpByIP[ipKey]; len(mac) != 0 && !ndpRefreshDue {
|
|
||||||
out := append(net.HardwareAddr(nil), mac...)
|
|
||||||
r.mu.RUnlock()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
r.mu.RUnlock()
|
|
||||||
|
|
||||||
if ifaceRefreshDue {
|
|
||||||
r.refreshIfaceCache(now)
|
|
||||||
}
|
|
||||||
if arpRefreshDue {
|
|
||||||
r.refreshARPCache(now)
|
|
||||||
}
|
|
||||||
if ndpRefreshDue {
|
|
||||||
r.refreshNDPCache(now)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
|
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
|
||||||
@@ -95,20 +64,40 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
|||||||
if mac := r.ndpByIP[ipKey]; len(mac) != 0 {
|
if mac := r.ndpByIP[ipKey]; len(mac) != 0 {
|
||||||
return append(net.HardwareAddr(nil), mac...)
|
return append(net.HardwareAddr(nil), mac...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// On-demand IPv6 neighbor lookup via route-netlink as a last fast path.
|
|
||||||
if ip.To4() == nil {
|
|
||||||
if mac, ok := lookupNeighborMACNetlink(ip); ok {
|
|
||||||
out := append(net.HardwareAddr(nil), mac...)
|
|
||||||
r.mu.Lock()
|
|
||||||
r.ndpByIP[ipKey] = append(net.HardwareAddr(nil), mac...)
|
|
||||||
r.mu.Unlock()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *sourceMACResolver) Run(ctx context.Context) {
|
||||||
|
r.refreshAll(time.Now())
|
||||||
|
ticker := time.NewTicker(arpCacheTTL)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case now := <-ticker.C:
|
||||||
|
r.refreshAll(now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *sourceMACResolver) refreshAll(now time.Time) {
|
||||||
|
r.mu.RLock()
|
||||||
|
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
|
||||||
|
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
|
||||||
|
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
|
||||||
|
r.mu.RUnlock()
|
||||||
|
if ifaceRefreshDue {
|
||||||
|
r.refreshIfaceCache(now)
|
||||||
|
}
|
||||||
|
if arpRefreshDue {
|
||||||
|
r.refreshARPCache(now)
|
||||||
|
}
|
||||||
|
if ndpRefreshDue {
|
||||||
|
r.refreshNDPCache(now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *sourceMACResolver) refreshIfaceCache(now time.Time) {
|
func (r *sourceMACResolver) refreshIfaceCache(now time.Time) {
|
||||||
interfaces, err := net.Interfaces()
|
interfaces, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
|
|
||||||
package engine
|
package engine
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
type sourceMACResolver struct{}
|
type sourceMACResolver struct{}
|
||||||
|
|
||||||
@@ -15,3 +18,7 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
|
|||||||
_ = ip
|
_ = ip
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *sourceMACResolver) Run(ctx context.Context) {
|
||||||
|
<-ctx.Done()
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type recordingPacket struct {
|
||||||
|
streamID uint32
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p recordingPacket) StreamID() uint32 { return p.streamID }
|
||||||
|
func (p recordingPacket) Data() []byte { return p.data }
|
||||||
|
|
||||||
|
type recordingPacketIO struct {
|
||||||
|
verdict io.Verdict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recordingPacketIO) Register(context.Context, io.PacketCallback) error { return nil }
|
||||||
|
func (r *recordingPacketIO) SetVerdict(_ io.Packet, v io.Verdict, _ []byte) error {
|
||||||
|
r.verdict = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *recordingPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *recordingPacketIO) Close() error { return nil }
|
||||||
|
|
||||||
|
func TestEngineDefaultOverflowPolicyAccepts(t *testing.T) {
|
||||||
|
packetIO := &recordingPacketIO{}
|
||||||
|
eng, err := NewEngine(Config{
|
||||||
|
Logger: noopTestLogger{},
|
||||||
|
IO: packetIO,
|
||||||
|
Workers: 1,
|
||||||
|
WorkerQueueSize: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewEngine error: %v", err)
|
||||||
|
}
|
||||||
|
e := eng.(*engine)
|
||||||
|
if e.overflowPolicy != OverflowPolicyAccept {
|
||||||
|
t.Fatalf("overflow policy=%v want=%v", e.overflowPolicy, OverflowPolicyAccept)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.workers[0].packetChan <- &workerPacket{}
|
||||||
|
packet := recordingPacket{
|
||||||
|
streamID: 1,
|
||||||
|
data: serializeIPv6TCP(
|
||||||
|
t,
|
||||||
|
net.ParseIP("2001:db8::11").To16(),
|
||||||
|
net.ParseIP("2001:db8::22").To16(),
|
||||||
|
42310,
|
||||||
|
443,
|
||||||
|
1000,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
e.dispatch(packet)
|
||||||
|
stats := e.Stats()
|
||||||
|
if packetIO.verdict != io.VerdictAccept {
|
||||||
|
t.Fatalf("overflow verdict=%v want=%v", packetIO.verdict, io.VerdictAccept)
|
||||||
|
}
|
||||||
|
if stats.OverflowEvents != 1 || stats.OverflowAccepts != 1 || stats.OverflowDrops != 0 {
|
||||||
|
t.Fatalf("overflow stats=%+v", stats)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,89 @@ func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
|||||||
return ruleset.MatchResult{Action: r.action}
|
return ruleset.MatchResult{Action: r.action}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type analyzerRuleset struct {
|
||||||
|
action ruleset.Action
|
||||||
|
ans []analyzer.Analyzer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r analyzerRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
|
||||||
|
return r.ans
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r analyzerRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
|
||||||
|
return ruleset.MatchResult{Action: r.action}
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingTCPAnalyzer struct {
|
||||||
|
newCalls *int
|
||||||
|
feedCalls *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a countingTCPAnalyzer) Name() string { return "tls" }
|
||||||
|
func (a countingTCPAnalyzer) Limit() int { return 0 }
|
||||||
|
func (a countingTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
|
||||||
|
(*a.newCalls)++
|
||||||
|
return countingTCPStream{feedCalls: a.feedCalls}
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingTCPStream struct {
|
||||||
|
feedCalls *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s countingTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
|
||||||
|
(*s.feedCalls)++
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s countingTCPStream) Close(bool) *analyzer.PropUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type logFinalizingRuleset struct {
|
||||||
|
ans []analyzer.Analyzer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r logFinalizingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
|
||||||
|
return r.ans
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r logFinalizingRuleset) Match(info ruleset.StreamInfo) ruleset.MatchResult {
|
||||||
|
if _, ok := info.Props["tls"]; ok {
|
||||||
|
return ruleset.MatchResult{Action: ruleset.ActionMaybe, Logged: true}
|
||||||
|
}
|
||||||
|
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r logFinalizingRuleset) CanFinalizeAfterLog(ruleset.StreamInfo, []string) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestPropTCPAnalyzer struct {
|
||||||
|
closeCalls *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a requestPropTCPAnalyzer) Name() string { return "tls" }
|
||||||
|
func (a requestPropTCPAnalyzer) Limit() int { return 0 }
|
||||||
|
func (a requestPropTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
|
||||||
|
return requestPropTCPStream{closeCalls: a.closeCalls}
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestPropTCPStream struct {
|
||||||
|
closeCalls *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s requestPropTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
|
||||||
|
return &analyzer.PropUpdate{
|
||||||
|
Type: analyzer.PropUpdateMerge,
|
||||||
|
M: analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||||
|
}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s requestPropTCPStream) Close(bool) *analyzer.PropUpdate {
|
||||||
|
(*s.closeCalls)++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type noopTestLogger struct{}
|
type noopTestLogger struct{}
|
||||||
|
|
||||||
func (noopTestLogger) WorkerStart(int) {}
|
func (noopTestLogger) WorkerStart(int) {}
|
||||||
@@ -207,3 +290,90 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
|||||||
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
|
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPFlowDelaysAnalyzerCreationUntilPayload(t *testing.T) {
|
||||||
|
node, err := snowflake.NewNode(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create node: %v", err)
|
||||||
|
}
|
||||||
|
newCalls := 0
|
||||||
|
feedCalls := 0
|
||||||
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
|
||||||
|
mgr.updateRuleset(analyzerRuleset{
|
||||||
|
action: ruleset.ActionMaybe,
|
||||||
|
ans: []analyzer.Analyzer{countingTCPAnalyzer{
|
||||||
|
newCalls: &newCalls,
|
||||||
|
feedCalls: &feedCalls,
|
||||||
|
}},
|
||||||
|
}, 0)
|
||||||
|
|
||||||
|
l3 := L3Info{
|
||||||
|
Version: 4,
|
||||||
|
Protocol: 6,
|
||||||
|
SrcIP: [4]byte{10, 0, 0, 1},
|
||||||
|
DstIP: [4]byte{10, 0, 0, 2},
|
||||||
|
}
|
||||||
|
tcp := TCPInfo{
|
||||||
|
SrcPort: 12345,
|
||||||
|
DstPort: 443,
|
||||||
|
Seq: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||||
|
if v != io.VerdictAccept {
|
||||||
|
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
|
||||||
|
}
|
||||||
|
if newCalls != 0 || feedCalls != 0 {
|
||||||
|
t.Fatalf("empty packet created/feed analyzer: new=%d feed=%d", newCalls, feedCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp.Seq = 101
|
||||||
|
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
|
||||||
|
if v != io.VerdictAccept {
|
||||||
|
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAccept)
|
||||||
|
}
|
||||||
|
if newCalls != 1 || feedCalls != 1 {
|
||||||
|
t.Fatalf("payload should create/feed analyzer once: new=%d feed=%d", newCalls, feedCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPFlowFinalizesAfterLogClassification(t *testing.T) {
|
||||||
|
node, err := snowflake.NewNode(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create node: %v", err)
|
||||||
|
}
|
||||||
|
closeCalls := 0
|
||||||
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
|
||||||
|
mgr.updateRuleset(logFinalizingRuleset{
|
||||||
|
ans: []analyzer.Analyzer{requestPropTCPAnalyzer{closeCalls: &closeCalls}},
|
||||||
|
}, 0)
|
||||||
|
|
||||||
|
l3 := L3Info{
|
||||||
|
Version: 4,
|
||||||
|
Protocol: 6,
|
||||||
|
SrcIP: [4]byte{10, 0, 0, 1},
|
||||||
|
DstIP: [4]byte{10, 0, 0, 2},
|
||||||
|
}
|
||||||
|
tcp := TCPInfo{
|
||||||
|
SrcPort: 12345,
|
||||||
|
DstPort: 443,
|
||||||
|
Seq: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||||
|
if v != io.VerdictAccept {
|
||||||
|
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp.Seq = 101
|
||||||
|
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
|
||||||
|
if v != io.VerdictAcceptStream {
|
||||||
|
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAcceptStream)
|
||||||
|
}
|
||||||
|
if closeCalls != 1 {
|
||||||
|
t.Fatalf("expected analyzer to be closed once after finalization, got %d", closeCalls)
|
||||||
|
}
|
||||||
|
if _, ok := mgr.flows[1]; ok {
|
||||||
|
t.Fatal("expected finalized TCP flow to be removed from manager")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+73
-2
@@ -27,6 +27,8 @@ type tcpFlow struct {
|
|||||||
streamID uint32
|
streamID uint32
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
|
srcIP net.IP
|
||||||
|
dstIP net.IP
|
||||||
|
|
||||||
dirSeq [2]uint32
|
dirSeq [2]uint32
|
||||||
dirBuf [2][]byte
|
dirBuf [2][]byte
|
||||||
@@ -41,6 +43,9 @@ type tcpFlow struct {
|
|||||||
lastVerdict io.Verdict
|
lastVerdict io.Verdict
|
||||||
feedCalled [2]bool
|
feedCalled [2]bool
|
||||||
lastSeen time.Time
|
lastSeen time.Time
|
||||||
|
|
||||||
|
pendingAnalyzers []analyzer.Analyzer
|
||||||
|
selector *analyzerSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpFlowEntry struct {
|
type tcpFlowEntry struct {
|
||||||
@@ -54,7 +59,7 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
|||||||
rs, version := f.currentRuleset()
|
rs, version := f.currentRuleset()
|
||||||
rulesetChanged := version != f.rulesetVersion
|
rulesetChanged := version != f.rulesetVersion
|
||||||
|
|
||||||
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
|
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() {
|
||||||
return f.lastVerdict
|
return f.lastVerdict
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,6 +73,9 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
|||||||
propUpdated := false
|
propUpdated := false
|
||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
dir, rev := f.resolveDirection(tcp)
|
dir, rev := f.resolveDirection(tcp)
|
||||||
|
if len(f.pendingAnalyzers) > 0 {
|
||||||
|
f.initPendingAnalyzers(payload)
|
||||||
|
}
|
||||||
expected := f.dirSeq[dir]
|
expected := f.dirSeq[dir]
|
||||||
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
||||||
f.feedCalled[dir] = true
|
f.feedCalled[dir] = true
|
||||||
@@ -108,6 +116,52 @@ func (f *tcpFlow) feedAnalyzers(rev bool) bool {
|
|||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) initPendingAnalyzers(payload []byte) {
|
||||||
|
baseAns := f.pendingAnalyzers
|
||||||
|
f.pendingAnalyzers = nil
|
||||||
|
if f.selector != nil {
|
||||||
|
baseAns = f.selector.SelectTCP(baseAns, payload)
|
||||||
|
}
|
||||||
|
ans := analyzersToTCPAnalyzers(baseAns)
|
||||||
|
if len(ans) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||||
|
for _, a := range ans {
|
||||||
|
entries = append(entries, &tcpFlowEntry{
|
||||||
|
Name: a.Name(),
|
||||||
|
Stream: a.NewTCP(analyzer.TCPInfo{
|
||||||
|
SrcIP: f.srcIP,
|
||||||
|
DstIP: f.dstIP,
|
||||||
|
SrcPort: f.srcPort,
|
||||||
|
DstPort: f.dstPort,
|
||||||
|
}, &analyzerLogger{
|
||||||
|
StreamID: f.info.ID,
|
||||||
|
Name: a.Name(),
|
||||||
|
Logger: f.logger,
|
||||||
|
}),
|
||||||
|
HasLimit: a.Limit() > 0,
|
||||||
|
Quota: a.Limit(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
f.activeEntries = append(f.activeEntries, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) hasPendingAnalyzers() bool {
|
||||||
|
return len(f.pendingAnalyzers) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) analyzerNames() []string {
|
||||||
|
names := make([]string, 0, len(f.activeEntries)+len(f.pendingAnalyzers))
|
||||||
|
for _, entry := range f.activeEntries {
|
||||||
|
names = append(names, entry.Name)
|
||||||
|
}
|
||||||
|
for _, a := range f.pendingAnalyzers {
|
||||||
|
names = append(names, a.Name())
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) {
|
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) {
|
||||||
if !propUpdated && !f.virgin && !rulesetChanged {
|
if !propUpdated && !f.virgin && !rulesetChanged {
|
||||||
return
|
return
|
||||||
@@ -125,11 +179,15 @@ func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bo
|
|||||||
f.lastVerdict = verdict
|
f.lastVerdict = verdict
|
||||||
f.closeActiveEntries()
|
f.closeActiveEntries()
|
||||||
f.logger.TCPStreamAction(f.info, action, false)
|
f.logger.TCPStreamAction(f.info, action, false)
|
||||||
|
} else if result.Logged && canFinalizeAfterLog(rs, f.info, f.analyzerNames()) {
|
||||||
|
f.lastVerdict = io.VerdictAcceptStream
|
||||||
|
f.closeActiveEntries()
|
||||||
|
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *tcpFlow) maybeFinalizeVerdict() {
|
func (f *tcpFlow) maybeFinalizeVerdict() {
|
||||||
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept {
|
if len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() && f.lastVerdict == io.VerdictAccept {
|
||||||
f.lastVerdict = io.VerdictAcceptStream
|
f.lastVerdict = io.VerdictAcceptStream
|
||||||
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
|
||||||
}
|
}
|
||||||
@@ -231,9 +289,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
|||||||
var ans []analyzer.TCPAnalyzer
|
var ans []analyzer.TCPAnalyzer
|
||||||
if rs != nil {
|
if rs != nil {
|
||||||
baseAns := rs.Analyzers(info)
|
baseAns := rs.Analyzers(info)
|
||||||
|
if len(payload) > 0 {
|
||||||
baseAns = m.selector.SelectTCP(baseAns, payload)
|
baseAns = m.selector.SelectTCP(baseAns, payload)
|
||||||
ans = analyzersToTCPAnalyzers(baseAns)
|
ans = analyzersToTCPAnalyzers(baseAns)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
entries := make([]*tcpFlowEntry, 0, len(ans))
|
entries := make([]*tcpFlowEntry, 0, len(ans))
|
||||||
for _, a := range ans {
|
for _, a := range ans {
|
||||||
entries = append(entries, &tcpFlowEntry{
|
entries = append(entries, &tcpFlowEntry{
|
||||||
@@ -257,6 +317,8 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
|||||||
streamID: streamID,
|
streamID: streamID,
|
||||||
srcPort: tcp.SrcPort,
|
srcPort: tcp.SrcPort,
|
||||||
dstPort: tcp.DstPort,
|
dstPort: tcp.DstPort,
|
||||||
|
srcIP: ipSrc,
|
||||||
|
dstIP: ipDst,
|
||||||
info: info,
|
info: info,
|
||||||
virgin: true,
|
virgin: true,
|
||||||
logger: m.logger,
|
logger: m.logger,
|
||||||
@@ -265,6 +327,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
|||||||
activeEntries: entries,
|
activeEntries: entries,
|
||||||
lastVerdict: io.VerdictAccept,
|
lastVerdict: io.VerdictAccept,
|
||||||
lastSeen: time.Now(),
|
lastSeen: time.Now(),
|
||||||
|
selector: m.selector,
|
||||||
|
}
|
||||||
|
if len(payload) == 0 && rs != nil {
|
||||||
|
flow.pendingAnalyzers = rs.Analyzers(info)
|
||||||
}
|
}
|
||||||
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
||||||
return flow
|
return flow
|
||||||
@@ -325,3 +391,8 @@ func actionToTCPVerdict(a ruleset.Action) io.Verdict {
|
|||||||
return io.VerdictAcceptStream
|
return io.VerdictAcceptStream
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canFinalizeAfterLog(rs ruleset.Ruleset, info ruleset.StreamInfo, activeAnalyzers []string) bool {
|
||||||
|
finalizer, ok := rs.(ruleset.LogFinalizer)
|
||||||
|
return ok && finalizer.CanFinalizeAfterLog(info, activeAnalyzers)
|
||||||
|
}
|
||||||
|
|||||||
+69
-18
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"container/list"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -12,7 +13,6 @@ import (
|
|||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
lru "github.com/hashicorp/golang-lru/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// udpVerdict is a subset of io.Verdict for UDP streams.
|
// udpVerdict is a subset of io.Verdict for UDP streams.
|
||||||
@@ -116,13 +116,16 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
|
|||||||
|
|
||||||
type udpStreamManager struct {
|
type udpStreamManager struct {
|
||||||
factory *udpStreamFactory
|
factory *udpStreamFactory
|
||||||
streams *lru.Cache[uint32, *udpStreamValue]
|
streams map[uint32]*list.Element
|
||||||
|
order *list.List
|
||||||
|
maxStreams int
|
||||||
tupleIndex map[udpTupleKey]uint32
|
tupleIndex map[udpTupleKey]uint32
|
||||||
streamTuples map[uint32]udpTupleKey
|
streamTuples map[uint32]udpTupleKey
|
||||||
stats *statsCounters
|
stats *statsCounters
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpStreamValue struct {
|
type udpStreamValue struct {
|
||||||
|
StreamID uint32
|
||||||
Stream *udpStream
|
Stream *udpStream
|
||||||
Tuple udpTupleKey
|
Tuple udpTupleKey
|
||||||
}
|
}
|
||||||
@@ -143,27 +146,23 @@ type udpTupleKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||||
|
if maxStreams <= 0 {
|
||||||
|
maxStreams = 1
|
||||||
|
}
|
||||||
m := &udpStreamManager{
|
m := &udpStreamManager{
|
||||||
factory: factory,
|
factory: factory,
|
||||||
|
streams: make(map[uint32]*list.Element, maxStreams),
|
||||||
|
order: list.New(),
|
||||||
|
maxStreams: maxStreams,
|
||||||
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
|
||||||
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
streamTuples: make(map[uint32]udpTupleKey, maxStreams),
|
||||||
stats: stats,
|
stats: stats,
|
||||||
}
|
}
|
||||||
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
|
||||||
if v != nil && v.Stream != nil {
|
|
||||||
v.Stream.Close()
|
|
||||||
}
|
|
||||||
m.removeTupleMappingLocked(k)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.streams = ss
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
||||||
value, ok := m.streams.Get(streamID)
|
value, ok := m.get(streamID)
|
||||||
if !ok {
|
if !ok {
|
||||||
if m.stats != nil {
|
if m.stats != nil {
|
||||||
m.stats.UDPTupleLookups.Add(1)
|
m.stats.UDPTupleLookups.Add(1)
|
||||||
@@ -176,7 +175,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
|||||||
m.stats.UDPTupleHits.Add(1)
|
m.stats.UDPTupleHits.Add(1)
|
||||||
}
|
}
|
||||||
var hasValue bool
|
var hasValue bool
|
||||||
matchedValue, hasValue = m.streams.Get(matchedKey)
|
matchedValue, hasValue = m.get(matchedKey)
|
||||||
if !hasValue || matchedValue == nil {
|
if !hasValue || matchedValue == nil {
|
||||||
delete(m.tupleIndex, tuple)
|
delete(m.tupleIndex, tuple)
|
||||||
delete(m.streamTuples, matchedKey)
|
delete(m.streamTuples, matchedKey)
|
||||||
@@ -188,16 +187,18 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
|||||||
value = matchedValue
|
value = matchedValue
|
||||||
rev = matchedRev
|
rev = matchedRev
|
||||||
if matchedKey != streamID {
|
if matchedKey != streamID {
|
||||||
m.streams.Remove(matchedKey)
|
m.remove(matchedKey, false)
|
||||||
m.streams.Add(streamID, matchedValue)
|
matchedValue.StreamID = streamID
|
||||||
|
m.add(streamID, matchedValue)
|
||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
value = &udpStreamValue{
|
value = &udpStreamValue{
|
||||||
|
StreamID: streamID,
|
||||||
Stream: m.factory.New(tuple, payload, uc),
|
Stream: m.factory.New(tuple, payload, uc),
|
||||||
Tuple: tuple,
|
Tuple: tuple,
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.add(streamID, value)
|
||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -205,10 +206,11 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
|||||||
if !ok {
|
if !ok {
|
||||||
value.Stream.Close()
|
value.Stream.Close()
|
||||||
value = &udpStreamValue{
|
value = &udpStreamValue{
|
||||||
|
StreamID: streamID,
|
||||||
Stream: m.factory.New(tuple, payload, uc),
|
Stream: m.factory.New(tuple, payload, uc),
|
||||||
Tuple: tuple,
|
Tuple: tuple,
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.add(streamID, value)
|
||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -217,6 +219,55 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *udpStreamManager) get(streamID uint32) (*udpStreamValue, bool) {
|
||||||
|
ele, ok := m.streams[streamID]
|
||||||
|
if !ok || ele == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
m.order.MoveToFront(ele)
|
||||||
|
value, ok := ele.Value.(*udpStreamValue)
|
||||||
|
return value, ok && value != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpStreamManager) add(streamID uint32, value *udpStreamValue) {
|
||||||
|
if value == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, ok := m.streams[streamID]; ok {
|
||||||
|
existing.Value = value
|
||||||
|
m.order.MoveToFront(existing)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value.StreamID = streamID
|
||||||
|
m.streams[streamID] = m.order.PushFront(value)
|
||||||
|
for len(m.streams) > m.maxStreams {
|
||||||
|
back := m.order.Back()
|
||||||
|
if back == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
evicted, _ := back.Value.(*udpStreamValue)
|
||||||
|
if evicted == nil {
|
||||||
|
m.order.Remove(back)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.remove(evicted.StreamID, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpStreamManager) remove(streamID uint32, closeStream bool) {
|
||||||
|
ele, ok := m.streams[streamID]
|
||||||
|
if !ok || ele == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value, _ := ele.Value.(*udpStreamValue)
|
||||||
|
delete(m.streams, streamID)
|
||||||
|
m.order.Remove(ele)
|
||||||
|
m.removeTupleMappingLocked(streamID)
|
||||||
|
if closeStream && value != nil && value.Stream != nil {
|
||||||
|
value.Stream.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
func (m *udpStreamManager) bindTupleLocked(streamID uint32, key udpTupleKey) {
|
||||||
m.removeTupleMappingLocked(streamID)
|
m.removeTupleMappingLocked(streamID)
|
||||||
m.tupleIndex[key] = streamID
|
m.tupleIndex[key] = streamID
|
||||||
|
|||||||
@@ -112,15 +112,20 @@ type geositeDomain struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type geositeMatcher struct {
|
type geositeMatcher struct {
|
||||||
Domains []geositeDomain
|
Domains []geositeDomain // legacy slow path for tests and manual construction
|
||||||
|
Plain []geositeDomain
|
||||||
|
Regex []geositeDomain
|
||||||
|
Root map[string]geositeDomain
|
||||||
|
Full map[string]geositeDomain
|
||||||
// Attributes are matched using "and" logic - if you have multiple attributes here,
|
// Attributes are matched using "and" logic - if you have multiple attributes here,
|
||||||
// a domain must have all of those attributes to be considered a match.
|
// a domain must have all of those attributes to be considered a match.
|
||||||
Attrs []string
|
Attrs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
func (m *geositeMatcher) attrsMatch(domain geositeDomain) bool {
|
||||||
// Match attributes first
|
if len(m.Attrs) == 0 {
|
||||||
if len(m.Attrs) > 0 {
|
return true
|
||||||
|
}
|
||||||
if len(domain.Attrs) == 0 {
|
if len(domain.Attrs) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -129,6 +134,13 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
||||||
|
// Match attributes first
|
||||||
|
if !m.attrsMatch(domain) {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch domain.Type {
|
switch domain.Type {
|
||||||
@@ -152,54 +164,90 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *geositeMatcher) Match(host HostInfo) bool {
|
func (m *geositeMatcher) Match(host HostInfo) bool {
|
||||||
|
if host.Name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if domain, ok := m.Full[host.Name]; ok && m.attrsMatch(domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for name := host.Name; name != ""; {
|
||||||
|
if domain, ok := m.Root[name]; ok && m.attrsMatch(domain) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
idx := strings.IndexByte(name, '.')
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name = name[idx+1:]
|
||||||
|
}
|
||||||
|
for _, domain := range m.Plain {
|
||||||
|
if m.matchDomain(domain, host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(m.Plain) == 0 && len(m.Regex) == 0 && len(m.Root) == 0 && len(m.Full) == 0 {
|
||||||
for _, domain := range m.Domains {
|
for _, domain := range m.Domains {
|
||||||
if m.matchDomain(domain, host) {
|
if m.matchDomain(domain, host) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
for _, domain := range m.Regex {
|
||||||
|
if m.matchDomain(domain, host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
|
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
|
||||||
domains := make([]geositeDomain, len(list.Domain))
|
matcher := &geositeMatcher{
|
||||||
for i, domain := range list.Domain {
|
Root: make(map[string]geositeDomain),
|
||||||
|
Full: make(map[string]geositeDomain),
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
for _, domain := range list.Domain {
|
||||||
|
var compiled geositeDomain
|
||||||
switch domain.Type {
|
switch domain.Type {
|
||||||
case v2geo.Domain_Plain:
|
case v2geo.Domain_Plain:
|
||||||
domains[i] = geositeDomain{
|
compiled = geositeDomain{
|
||||||
Type: geositeDomainPlain,
|
Type: geositeDomainPlain,
|
||||||
Value: domain.Value,
|
Value: domain.Value,
|
||||||
Attrs: domainAttributeToMap(domain.Attribute),
|
Attrs: domainAttributeToMap(domain.Attribute),
|
||||||
}
|
}
|
||||||
|
matcher.Plain = append(matcher.Plain, compiled)
|
||||||
case v2geo.Domain_Regex:
|
case v2geo.Domain_Regex:
|
||||||
regex, err := regexp.Compile(domain.Value)
|
regex, err := regexp.Compile(domain.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
domains[i] = geositeDomain{
|
compiled = geositeDomain{
|
||||||
Type: geositeDomainRegex,
|
Type: geositeDomainRegex,
|
||||||
|
Value: domain.Value,
|
||||||
Regex: regex,
|
Regex: regex,
|
||||||
Attrs: domainAttributeToMap(domain.Attribute),
|
Attrs: domainAttributeToMap(domain.Attribute),
|
||||||
}
|
}
|
||||||
|
matcher.Regex = append(matcher.Regex, compiled)
|
||||||
case v2geo.Domain_Full:
|
case v2geo.Domain_Full:
|
||||||
domains[i] = geositeDomain{
|
compiled = geositeDomain{
|
||||||
Type: geositeDomainFull,
|
Type: geositeDomainFull,
|
||||||
Value: domain.Value,
|
Value: domain.Value,
|
||||||
Attrs: domainAttributeToMap(domain.Attribute),
|
Attrs: domainAttributeToMap(domain.Attribute),
|
||||||
}
|
}
|
||||||
|
matcher.Full[domain.Value] = compiled
|
||||||
case v2geo.Domain_RootDomain:
|
case v2geo.Domain_RootDomain:
|
||||||
domains[i] = geositeDomain{
|
compiled = geositeDomain{
|
||||||
Type: geositeDomainRoot,
|
Type: geositeDomainRoot,
|
||||||
Value: domain.Value,
|
Value: domain.Value,
|
||||||
Attrs: domainAttributeToMap(domain.Attribute),
|
Attrs: domainAttributeToMap(domain.Attribute),
|
||||||
}
|
}
|
||||||
|
matcher.Root[domain.Value] = compiled
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported domain type")
|
return nil, errors.New("unsupported domain type")
|
||||||
}
|
}
|
||||||
|
matcher.Domains = append(matcher.Domains, compiled)
|
||||||
}
|
}
|
||||||
return &geositeMatcher{
|
return matcher, nil
|
||||||
Domains: domains,
|
|
||||||
Attrs: attrs,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
|
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
|
||||||
|
|||||||
+120
-5
@@ -59,6 +59,8 @@ type compiledExprRule struct {
|
|||||||
Log bool
|
Log bool
|
||||||
ModInstance modifier.Instance
|
ModInstance modifier.Instance
|
||||||
Program *vm.Program
|
Program *vm.Program
|
||||||
|
Native nativeExpr
|
||||||
|
AnalyzerRefs map[string]analyzerRuleRef
|
||||||
GeoSiteConditions []string
|
GeoSiteConditions []string
|
||||||
StartTimeSecs int // seconds since midnight, -1 if unset
|
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||||
StopTimeSecs int // seconds since midnight, -1 if unset
|
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||||
@@ -67,6 +69,7 @@ type compiledExprRule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var _ Ruleset = (*exprRuleset)(nil)
|
var _ Ruleset = (*exprRuleset)(nil)
|
||||||
|
var _ LogFinalizer = (*exprRuleset)(nil)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
envPool = sync.Pool{
|
envPool = sync.Pool{
|
||||||
@@ -102,10 +105,12 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
env := envPool.Get().(map[string]any)
|
var env map[string]any
|
||||||
clear(env)
|
var macMap, ipMap, portMap map[string]any
|
||||||
macMap, ipMap, portMap := populateExprEnv(env, info)
|
|
||||||
releaseEnv := func() {
|
releaseEnv := func() {
|
||||||
|
if env == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
clear(env)
|
clear(env)
|
||||||
envPool.Put(env)
|
envPool.Put(env)
|
||||||
putSubMap(macMap)
|
putSubMap(macMap)
|
||||||
@@ -113,10 +118,20 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
putSubMap(portMap)
|
putSubMap(portMap)
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
logged := false
|
||||||
for _, rule := range r.Rules {
|
for _, rule := range r.Rules {
|
||||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
matched := false
|
||||||
|
if rule.Native != nil {
|
||||||
|
matched = rule.Native.Match(info)
|
||||||
|
} else {
|
||||||
|
if env == nil {
|
||||||
|
env = envPool.Get().(map[string]any)
|
||||||
|
clear(env)
|
||||||
|
macMap, ipMap, portMap = populateExprEnv(env, info)
|
||||||
|
}
|
||||||
v, err := vm.Run(rule.Program, env)
|
v, err := vm.Run(rule.Program, env)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if r.stats != nil {
|
if r.stats != nil {
|
||||||
@@ -125,19 +140,23 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
r.Logger.MatchError(info, rule.Name, err)
|
r.Logger.MatchError(info, rule.Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if vBool, ok := v.(bool); ok && vBool {
|
matched, _ = v.(bool)
|
||||||
|
}
|
||||||
|
if matched {
|
||||||
if rule.Log {
|
if rule.Log {
|
||||||
logInfo := info
|
logInfo := info
|
||||||
if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil {
|
if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil {
|
||||||
logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions)
|
logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions)
|
||||||
}
|
}
|
||||||
r.Logger.Log(logInfo, rule.Name)
|
r.Logger.Log(logInfo, rule.Name)
|
||||||
|
logged = true
|
||||||
}
|
}
|
||||||
if rule.Action != nil {
|
if rule.Action != nil {
|
||||||
releaseEnv()
|
releaseEnv()
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: *rule.Action,
|
Action: *rule.Action,
|
||||||
ModInstance: rule.ModInstance,
|
ModInstance: rule.ModInstance,
|
||||||
|
Logged: logged,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -145,9 +164,40 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
|||||||
releaseEnv()
|
releaseEnv()
|
||||||
return MatchResult{
|
return MatchResult{
|
||||||
Action: ActionMaybe,
|
Action: ActionMaybe,
|
||||||
|
Logged: logged,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *exprRuleset) CanFinalizeAfterLog(info StreamInfo, activeAnalyzers []string) bool {
|
||||||
|
active := make(map[string]bool, len(activeAnalyzers))
|
||||||
|
for _, name := range activeAnalyzers {
|
||||||
|
active[name] = true
|
||||||
|
}
|
||||||
|
for _, rule := range r.Rules {
|
||||||
|
if rule.Action == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if *rule.Action == ActionModify {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rule.StartTimeSecs != -1 || rule.StopTimeSecs != -1 || len(rule.Weekdays) != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for name, ref := range rule.AnalyzerRefs {
|
||||||
|
if !active[name] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ref.ResponseSide {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := info.Props[name]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (r *exprRuleset) Stats() Stats {
|
func (r *exprRuleset) Stats() Stats {
|
||||||
if r == nil || r.stats == nil {
|
if r == nil || r.stats == nil {
|
||||||
return Stats{}
|
return Stats{}
|
||||||
@@ -242,17 +292,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
||||||
}
|
}
|
||||||
|
var analyzerRefs map[string]analyzerRuleRef
|
||||||
|
if refTree, err := parser.Parse(rule.Expr); err == nil && refTree != nil {
|
||||||
|
analyzerRefs = collectAnalyzerRefs(refTree.Node, fullAnMap)
|
||||||
|
}
|
||||||
cr := compiledExprRule{
|
cr := compiledExprRule{
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Action: action,
|
Action: action,
|
||||||
Log: rule.Log,
|
Log: rule.Log,
|
||||||
Program: program,
|
Program: program,
|
||||||
|
AnalyzerRefs: analyzerRefs,
|
||||||
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
||||||
StartTimeSecs: startSecs,
|
StartTimeSecs: startSecs,
|
||||||
StopTimeSecs: stopSecs,
|
StopTimeSecs: stopSecs,
|
||||||
Weekdays: weekdays,
|
Weekdays: weekdays,
|
||||||
WeekdaysNegated: weekdaysNegated,
|
WeekdaysNegated: weekdaysNegated,
|
||||||
}
|
}
|
||||||
|
cr.Native = compileNativeExpr(rule.Expr, funcMap, geoMatcher)
|
||||||
if action != nil && *action == ActionModify {
|
if action != nil && *action == ActionModify {
|
||||||
mod, ok := fullModMap[rule.Modifier.Name]
|
mod, ok := fullModMap[rule.Modifier.Name]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -266,9 +322,16 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
}
|
}
|
||||||
compiledRules = append(compiledRules, cr)
|
compiledRules = append(compiledRules, cr)
|
||||||
}
|
}
|
||||||
|
depAns := make([]analyzer.Analyzer, 0, len(depAnMap))
|
||||||
|
for _, a := range ans {
|
||||||
|
if depAnMap[a.Name()] != nil {
|
||||||
|
depAns = append(depAns, a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &exprRuleset{
|
return &exprRuleset{
|
||||||
Rules: compiledRules,
|
Rules: compiledRules,
|
||||||
Ans: ans,
|
Ans: depAns,
|
||||||
Logger: config.Logger,
|
Logger: config.Logger,
|
||||||
GeoMatcher: geoMatcher,
|
GeoMatcher: geoMatcher,
|
||||||
stats: stats,
|
stats: stats,
|
||||||
@@ -373,6 +436,58 @@ func (v *idVisitor) Visit(node *ast.Node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type analyzerRuleRef struct {
|
||||||
|
ResponseSide bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type analyzerRefVisitor struct {
|
||||||
|
Analyzers map[string]analyzer.Analyzer
|
||||||
|
Refs map[string]analyzerRuleRef
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectAnalyzerRefs(root ast.Node, analyzers map[string]analyzer.Analyzer) map[string]analyzerRuleRef {
|
||||||
|
visitor := &analyzerRefVisitor{
|
||||||
|
Analyzers: analyzers,
|
||||||
|
Refs: make(map[string]analyzerRuleRef),
|
||||||
|
}
|
||||||
|
ast.Walk(&root, visitor)
|
||||||
|
return visitor.Refs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *analyzerRefVisitor) Visit(node *ast.Node) {
|
||||||
|
switch n := (*node).(type) {
|
||||||
|
case *ast.IdentifierNode:
|
||||||
|
if _, ok := v.Analyzers[n.Value]; ok {
|
||||||
|
v.add(n.Value, false)
|
||||||
|
}
|
||||||
|
case *ast.MemberNode:
|
||||||
|
path := memberPath(n)
|
||||||
|
if len(path) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name := path[0]
|
||||||
|
if _, ok := v.Analyzers[name]; !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v.add(name, len(path) > 1 && isResponseSideAnalyzerPath(path[1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *analyzerRefVisitor) add(name string, responseSide bool) {
|
||||||
|
ref := v.Refs[name]
|
||||||
|
ref.ResponseSide = ref.ResponseSide || responseSide
|
||||||
|
v.Refs[name] = ref
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResponseSideAnalyzerPath(name string) bool {
|
||||||
|
switch name {
|
||||||
|
case "resp", "server", "answers", "response":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// idPatcher patches the AST during expr compilation, replacing certain values with
|
// idPatcher patches the AST during expr compilation, replacing certain values with
|
||||||
// their internal representations for better runtime performance.
|
// their internal representations for better runtime performance.
|
||||||
type idPatcher struct {
|
type idPatcher struct {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package ruleset
|
package ruleset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -12,6 +13,13 @@ import (
|
|||||||
"github.com/expr-lang/expr/parser"
|
"github.com/expr-lang/expr/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type testAnalyzer struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a testAnalyzer) Name() string { return a.name }
|
||||||
|
func (a testAnalyzer) Limit() int { return 0 }
|
||||||
|
|
||||||
func TestExtractGeoSiteConditions(t *testing.T) {
|
func TestExtractGeoSiteConditions(t *testing.T) {
|
||||||
expression := `
|
expression := `
|
||||||
(geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) &&
|
(geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) &&
|
||||||
@@ -88,3 +96,93 @@ func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
|
|||||||
t.Fatalf("expected OR chain to be collapsed, got %q", got)
|
t.Fatalf("expected OR chain to be collapsed, got %q", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) {
|
||||||
|
rs, err := CompileExprRules([]ExprRule{
|
||||||
|
{Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`},
|
||||||
|
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompileExprRules error: %v", err)
|
||||||
|
}
|
||||||
|
exprRS := rs.(*exprRuleset)
|
||||||
|
if len(exprRS.Ans) != 0 {
|
||||||
|
t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans))
|
||||||
|
}
|
||||||
|
if exprRS.Rules[0].Native == nil {
|
||||||
|
t.Fatalf("expected network-only rule to compile to native matcher")
|
||||||
|
}
|
||||||
|
got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443})
|
||||||
|
if got.Action != ActionAllow {
|
||||||
|
t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) {
|
||||||
|
rs, err := CompileExprRules([]ExprRule{
|
||||||
|
{Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`},
|
||||||
|
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompileExprRules error: %v", err)
|
||||||
|
}
|
||||||
|
exprRS := rs.(*exprRuleset)
|
||||||
|
if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" {
|
||||||
|
t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNativeCIDRMatcher(t *testing.T) {
|
||||||
|
funcMap, geoMatcher := buildFunctionMapForTest()
|
||||||
|
n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher)
|
||||||
|
if n == nil {
|
||||||
|
t.Fatal("expected native matcher")
|
||||||
|
}
|
||||||
|
if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) {
|
||||||
|
t.Fatal("expected native CIDR matcher to match")
|
||||||
|
}
|
||||||
|
if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) {
|
||||||
|
t.Fatal("expected native CIDR matcher not to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) {
|
||||||
|
rs, err := CompileExprRules([]ExprRule{
|
||||||
|
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
|
||||||
|
{Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`},
|
||||||
|
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompileExprRules error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info := StreamInfo{
|
||||||
|
Props: analyzer.CombinedPropMap{
|
||||||
|
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
|
||||||
|
t.Fatal("expected request-only rules to allow log finalization once request props exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) {
|
||||||
|
rs, err := CompileExprRules([]ExprRule{
|
||||||
|
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
|
||||||
|
{Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`},
|
||||||
|
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompileExprRules error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info := StreamInfo{
|
||||||
|
Props: analyzer.CombinedPropMap{
|
||||||
|
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
|
||||||
|
t.Fatal("expected response-side rule to keep inspection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) {
|
||||||
|
m, g := buildFunctionMap(&BuiltinConfig{}, nil)
|
||||||
|
return m, g
|
||||||
|
}
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ func (i StreamInfo) DstString() string {
|
|||||||
type MatchResult struct {
|
type MatchResult struct {
|
||||||
Action Action
|
Action Action
|
||||||
ModInstance modifier.Instance
|
ModInstance modifier.Instance
|
||||||
|
Logged bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruleset interface {
|
type Ruleset interface {
|
||||||
@@ -96,6 +97,10 @@ type Ruleset interface {
|
|||||||
Match(StreamInfo) MatchResult
|
Match(StreamInfo) MatchResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LogFinalizer interface {
|
||||||
|
CanFinalizeAfterLog(StreamInfo, []string) bool
|
||||||
|
}
|
||||||
|
|
||||||
type Stats struct {
|
type Stats struct {
|
||||||
MatchCalls uint64
|
MatchCalls uint64
|
||||||
MatchErrors uint64
|
MatchErrors uint64
|
||||||
|
|||||||
@@ -0,0 +1,273 @@
|
|||||||
|
package ruleset
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/expr-lang/expr/ast"
|
||||||
|
"github.com/expr-lang/expr/parser"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nativeExpr interface {
|
||||||
|
Match(StreamInfo) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type nativeBoolFunc func(StreamInfo) bool
|
||||||
|
|
||||||
|
func (f nativeBoolFunc) Match(info StreamInfo) bool {
|
||||||
|
return f(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
type nativeValueFunc func(StreamInfo) (any, bool)
|
||||||
|
|
||||||
|
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
|
||||||
|
tree, err := parser.Parse(expression)
|
||||||
|
if err != nil || tree == nil || tree.Node == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
root := tree.Node
|
||||||
|
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
|
||||||
|
ast.Walk(&root, patcher)
|
||||||
|
if patcher.Err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return compileNativeBool(root)
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileNativeBool(node ast.Node) nativeExpr {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *ast.BinaryNode:
|
||||||
|
switch n.Operator {
|
||||||
|
case "&&", "and":
|
||||||
|
left := compileNativeBool(n.Left)
|
||||||
|
right := compileNativeBool(n.Right)
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||||
|
return left.Match(info) && right.Match(info)
|
||||||
|
})
|
||||||
|
case "||", "or":
|
||||||
|
left := compileNativeBool(n.Left)
|
||||||
|
right := compileNativeBool(n.Right)
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||||
|
return left.Match(info) || right.Match(info)
|
||||||
|
})
|
||||||
|
case "==", "!=", ">", ">=", "<", "<=":
|
||||||
|
left := compileNativeValue(n.Left)
|
||||||
|
right := compileNativeValue(n.Right)
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
op := n.Operator
|
||||||
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||||
|
lv, lok := left(info)
|
||||||
|
rv, rok := right(info)
|
||||||
|
if !lok || !rok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
result, ok := compareNativeValues(lv, rv, op)
|
||||||
|
return ok && result
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *ast.UnaryNode:
|
||||||
|
if n.Operator != "!" && n.Operator != "not" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
child := compileNativeBool(n.Node)
|
||||||
|
if child == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||||
|
return !child.Match(info)
|
||||||
|
})
|
||||||
|
case *ast.CallNode:
|
||||||
|
return compileNativeCall(n)
|
||||||
|
case *ast.BoolNode:
|
||||||
|
value := n.Value
|
||||||
|
return nativeBoolFunc(func(StreamInfo) bool { return value })
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileNativeCall(n *ast.CallNode) nativeExpr {
|
||||||
|
id, ok := n.Callee.(*ast.IdentifierNode)
|
||||||
|
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ipValue := compileNativeValue(n.Arguments[0])
|
||||||
|
if ipValue == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var cidr *net.IPNet
|
||||||
|
switch arg := n.Arguments[1].(type) {
|
||||||
|
case *ast.ConstantNode:
|
||||||
|
cidr, _ = arg.Value.(*net.IPNet)
|
||||||
|
case *ast.StringNode:
|
||||||
|
_, parsed, err := net.ParseCIDR(arg.Value)
|
||||||
|
if err == nil {
|
||||||
|
cidr = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cidr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||||
|
value, ok := ipValue(info)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case net.IP:
|
||||||
|
return cidr.Contains(v)
|
||||||
|
case string:
|
||||||
|
ip := net.ParseIP(v)
|
||||||
|
return ip != nil && cidr.Contains(ip)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileNativeValue(node ast.Node) nativeValueFunc {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *ast.StringNode:
|
||||||
|
value := n.Value
|
||||||
|
return func(StreamInfo) (any, bool) { return value, true }
|
||||||
|
case *ast.IntegerNode:
|
||||||
|
value := int64(n.Value)
|
||||||
|
return func(StreamInfo) (any, bool) { return value, true }
|
||||||
|
case *ast.IdentifierNode:
|
||||||
|
switch strings.ToLower(n.Value) {
|
||||||
|
case "proto":
|
||||||
|
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *ast.MemberNode:
|
||||||
|
return compileNativeMember(n)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
|
||||||
|
path := memberPath(n)
|
||||||
|
switch strings.Join(path, ".") {
|
||||||
|
case "mac.src":
|
||||||
|
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
|
||||||
|
case "mac.dst":
|
||||||
|
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
|
||||||
|
case "ip.src":
|
||||||
|
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
|
||||||
|
case "ip.dst":
|
||||||
|
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
|
||||||
|
case "port.src":
|
||||||
|
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
|
||||||
|
case "port.dst":
|
||||||
|
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func memberPath(node ast.Node) []string {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *ast.IdentifierNode:
|
||||||
|
return []string{strings.ToLower(n.Value)}
|
||||||
|
case *ast.MemberNode:
|
||||||
|
base := memberPath(n.Node)
|
||||||
|
prop, ok := n.Property.(*ast.StringNode)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append(base, strings.ToLower(prop.Value))
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareNativeValues(left, right any, op string) (bool, bool) {
|
||||||
|
if li, lok := nativeInt(left); lok {
|
||||||
|
ri, rok := nativeInt(right)
|
||||||
|
if !rok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return compareNativeOrdered(li, ri, op), true
|
||||||
|
}
|
||||||
|
ls, lok := nativeString(left)
|
||||||
|
if !lok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
rs, rok := nativeString(right)
|
||||||
|
if !rok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
switch op {
|
||||||
|
case "==":
|
||||||
|
return ls == rs, true
|
||||||
|
case "!=":
|
||||||
|
return ls != rs, true
|
||||||
|
default:
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareNativeOrdered(left, right int64, op string) bool {
|
||||||
|
switch op {
|
||||||
|
case "==":
|
||||||
|
return left == right
|
||||||
|
case "!=":
|
||||||
|
return left != right
|
||||||
|
case ">":
|
||||||
|
return left > right
|
||||||
|
case ">=":
|
||||||
|
return left >= right
|
||||||
|
case "<":
|
||||||
|
return left < right
|
||||||
|
case "<=":
|
||||||
|
return left <= right
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func nativeInt(v any) (int64, bool) {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(n), true
|
||||||
|
case int64:
|
||||||
|
return n, true
|
||||||
|
case uint16:
|
||||||
|
return int64(n), true
|
||||||
|
case *ast.IntegerNode:
|
||||||
|
return int64(n.Value), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func nativeString(v any) (string, bool) {
|
||||||
|
switch s := v.(type) {
|
||||||
|
case string:
|
||||||
|
return s, true
|
||||||
|
case net.IP:
|
||||||
|
if s == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return s.String(), true
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(s, 10), true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user