engine: more performance improvements
This commit is contained in:
@@ -112,23 +112,35 @@ type geositeDomain struct {
|
||||
}
|
||||
|
||||
type geositeMatcher struct {
|
||||
Domains []geositeDomain
|
||||
Domains []geositeDomain // legacy slow path for tests and manual construction
|
||||
Plain []geositeDomain
|
||||
Regex []geositeDomain
|
||||
Root map[string]geositeDomain
|
||||
Full map[string]geositeDomain
|
||||
// Attributes are matched using "and" logic - if you have multiple attributes here,
|
||||
// a domain must have all of those attributes to be considered a match.
|
||||
Attrs []string
|
||||
}
|
||||
|
||||
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
||||
// Match attributes first
|
||||
if len(m.Attrs) > 0 {
|
||||
if len(domain.Attrs) == 0 {
|
||||
func (m *geositeMatcher) attrsMatch(domain geositeDomain) bool {
|
||||
if len(m.Attrs) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(domain.Attrs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, attr := range m.Attrs {
|
||||
if !domain.Attrs[attr] {
|
||||
return false
|
||||
}
|
||||
for _, attr := range m.Attrs {
|
||||
if !domain.Attrs[attr] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
||||
// Match attributes first
|
||||
if !m.attrsMatch(domain) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch domain.Type {
|
||||
@@ -152,7 +164,35 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
|
||||
}
|
||||
|
||||
func (m *geositeMatcher) Match(host HostInfo) bool {
|
||||
for _, domain := range m.Domains {
|
||||
if host.Name == "" {
|
||||
return false
|
||||
}
|
||||
if domain, ok := m.Full[host.Name]; ok && m.attrsMatch(domain) {
|
||||
return true
|
||||
}
|
||||
for name := host.Name; name != ""; {
|
||||
if domain, ok := m.Root[name]; ok && m.attrsMatch(domain) {
|
||||
return true
|
||||
}
|
||||
idx := strings.IndexByte(name, '.')
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
name = name[idx+1:]
|
||||
}
|
||||
for _, domain := range m.Plain {
|
||||
if m.matchDomain(domain, host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(m.Plain) == 0 && len(m.Regex) == 0 && len(m.Root) == 0 && len(m.Full) == 0 {
|
||||
for _, domain := range m.Domains {
|
||||
if m.matchDomain(domain, host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, domain := range m.Regex {
|
||||
if m.matchDomain(domain, host) {
|
||||
return true
|
||||
}
|
||||
@@ -161,45 +201,53 @@ func (m *geositeMatcher) Match(host HostInfo) bool {
|
||||
}
|
||||
|
||||
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
|
||||
domains := make([]geositeDomain, len(list.Domain))
|
||||
for i, domain := range list.Domain {
|
||||
matcher := &geositeMatcher{
|
||||
Root: make(map[string]geositeDomain),
|
||||
Full: make(map[string]geositeDomain),
|
||||
Attrs: attrs,
|
||||
}
|
||||
for _, domain := range list.Domain {
|
||||
var compiled geositeDomain
|
||||
switch domain.Type {
|
||||
case v2geo.Domain_Plain:
|
||||
domains[i] = geositeDomain{
|
||||
compiled = geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: domain.Value,
|
||||
Attrs: domainAttributeToMap(domain.Attribute),
|
||||
}
|
||||
matcher.Plain = append(matcher.Plain, compiled)
|
||||
case v2geo.Domain_Regex:
|
||||
regex, err := regexp.Compile(domain.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains[i] = geositeDomain{
|
||||
compiled = geositeDomain{
|
||||
Type: geositeDomainRegex,
|
||||
Value: domain.Value,
|
||||
Regex: regex,
|
||||
Attrs: domainAttributeToMap(domain.Attribute),
|
||||
}
|
||||
matcher.Regex = append(matcher.Regex, compiled)
|
||||
case v2geo.Domain_Full:
|
||||
domains[i] = geositeDomain{
|
||||
compiled = geositeDomain{
|
||||
Type: geositeDomainFull,
|
||||
Value: domain.Value,
|
||||
Attrs: domainAttributeToMap(domain.Attribute),
|
||||
}
|
||||
matcher.Full[domain.Value] = compiled
|
||||
case v2geo.Domain_RootDomain:
|
||||
domains[i] = geositeDomain{
|
||||
compiled = geositeDomain{
|
||||
Type: geositeDomainRoot,
|
||||
Value: domain.Value,
|
||||
Attrs: domainAttributeToMap(domain.Attribute),
|
||||
}
|
||||
matcher.Root[domain.Value] = compiled
|
||||
default:
|
||||
return nil, errors.New("unsupported domain type")
|
||||
}
|
||||
matcher.Domains = append(matcher.Domains, compiled)
|
||||
}
|
||||
return &geositeMatcher{
|
||||
Domains: domains,
|
||||
Attrs: attrs,
|
||||
}, nil
|
||||
return matcher, nil
|
||||
}
|
||||
|
||||
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
|
||||
|
||||
+126
-11
@@ -59,6 +59,8 @@ type compiledExprRule struct {
|
||||
Log bool
|
||||
ModInstance modifier.Instance
|
||||
Program *vm.Program
|
||||
Native nativeExpr
|
||||
AnalyzerRefs map[string]analyzerRuleRef
|
||||
GeoSiteConditions []string
|
||||
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||
@@ -67,6 +69,7 @@ type compiledExprRule struct {
|
||||
}
|
||||
|
||||
var _ Ruleset = (*exprRuleset)(nil)
|
||||
var _ LogFinalizer = (*exprRuleset)(nil)
|
||||
|
||||
var (
|
||||
envPool = sync.Pool{
|
||||
@@ -102,10 +105,12 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||
}()
|
||||
}
|
||||
|
||||
env := envPool.Get().(map[string]any)
|
||||
clear(env)
|
||||
macMap, ipMap, portMap := populateExprEnv(env, info)
|
||||
var env map[string]any
|
||||
var macMap, ipMap, portMap map[string]any
|
||||
releaseEnv := func() {
|
||||
if env == nil {
|
||||
return
|
||||
}
|
||||
clear(env)
|
||||
envPool.Put(env)
|
||||
putSubMap(macMap)
|
||||
@@ -113,31 +118,45 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||
putSubMap(portMap)
|
||||
}
|
||||
now := time.Now()
|
||||
logged := false
|
||||
for _, rule := range r.Rules {
|
||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||
continue
|
||||
}
|
||||
v, err := vm.Run(rule.Program, env)
|
||||
if err != nil {
|
||||
if r.stats != nil {
|
||||
r.stats.MatchErrors.Add(1)
|
||||
matched := false
|
||||
if rule.Native != nil {
|
||||
matched = rule.Native.Match(info)
|
||||
} else {
|
||||
if env == nil {
|
||||
env = envPool.Get().(map[string]any)
|
||||
clear(env)
|
||||
macMap, ipMap, portMap = populateExprEnv(env, info)
|
||||
}
|
||||
r.Logger.MatchError(info, rule.Name, err)
|
||||
continue
|
||||
v, err := vm.Run(rule.Program, env)
|
||||
if err != nil {
|
||||
if r.stats != nil {
|
||||
r.stats.MatchErrors.Add(1)
|
||||
}
|
||||
r.Logger.MatchError(info, rule.Name, err)
|
||||
continue
|
||||
}
|
||||
matched, _ = v.(bool)
|
||||
}
|
||||
if vBool, ok := v.(bool); ok && vBool {
|
||||
if matched {
|
||||
if rule.Log {
|
||||
logInfo := info
|
||||
if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil {
|
||||
logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions)
|
||||
}
|
||||
r.Logger.Log(logInfo, rule.Name)
|
||||
logged = true
|
||||
}
|
||||
if rule.Action != nil {
|
||||
releaseEnv()
|
||||
return MatchResult{
|
||||
Action: *rule.Action,
|
||||
ModInstance: rule.ModInstance,
|
||||
Logged: logged,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,9 +164,40 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||
releaseEnv()
|
||||
return MatchResult{
|
||||
Action: ActionMaybe,
|
||||
Logged: logged,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *exprRuleset) CanFinalizeAfterLog(info StreamInfo, activeAnalyzers []string) bool {
|
||||
active := make(map[string]bool, len(activeAnalyzers))
|
||||
for _, name := range activeAnalyzers {
|
||||
active[name] = true
|
||||
}
|
||||
for _, rule := range r.Rules {
|
||||
if rule.Action == nil {
|
||||
continue
|
||||
}
|
||||
if *rule.Action == ActionModify {
|
||||
return false
|
||||
}
|
||||
if rule.StartTimeSecs != -1 || rule.StopTimeSecs != -1 || len(rule.Weekdays) != 0 {
|
||||
return false
|
||||
}
|
||||
for name, ref := range rule.AnalyzerRefs {
|
||||
if !active[name] {
|
||||
continue
|
||||
}
|
||||
if ref.ResponseSide {
|
||||
return false
|
||||
}
|
||||
if _, ok := info.Props[name]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *exprRuleset) Stats() Stats {
|
||||
if r == nil || r.stats == nil {
|
||||
return Stats{}
|
||||
@@ -242,17 +292,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
||||
}
|
||||
var analyzerRefs map[string]analyzerRuleRef
|
||||
if refTree, err := parser.Parse(rule.Expr); err == nil && refTree != nil {
|
||||
analyzerRefs = collectAnalyzerRefs(refTree.Node, fullAnMap)
|
||||
}
|
||||
cr := compiledExprRule{
|
||||
Name: rule.Name,
|
||||
Action: action,
|
||||
Log: rule.Log,
|
||||
Program: program,
|
||||
AnalyzerRefs: analyzerRefs,
|
||||
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
||||
StartTimeSecs: startSecs,
|
||||
StopTimeSecs: stopSecs,
|
||||
Weekdays: weekdays,
|
||||
WeekdaysNegated: weekdaysNegated,
|
||||
}
|
||||
cr.Native = compileNativeExpr(rule.Expr, funcMap, geoMatcher)
|
||||
if action != nil && *action == ActionModify {
|
||||
mod, ok := fullModMap[rule.Modifier.Name]
|
||||
if !ok {
|
||||
@@ -266,9 +322,16 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
}
|
||||
compiledRules = append(compiledRules, cr)
|
||||
}
|
||||
depAns := make([]analyzer.Analyzer, 0, len(depAnMap))
|
||||
for _, a := range ans {
|
||||
if depAnMap[a.Name()] != nil {
|
||||
depAns = append(depAns, a)
|
||||
}
|
||||
}
|
||||
|
||||
return &exprRuleset{
|
||||
Rules: compiledRules,
|
||||
Ans: ans,
|
||||
Ans: depAns,
|
||||
Logger: config.Logger,
|
||||
GeoMatcher: geoMatcher,
|
||||
stats: stats,
|
||||
@@ -373,6 +436,58 @@ func (v *idVisitor) Visit(node *ast.Node) {
|
||||
}
|
||||
}
|
||||
|
||||
type analyzerRuleRef struct {
|
||||
ResponseSide bool
|
||||
}
|
||||
|
||||
type analyzerRefVisitor struct {
|
||||
Analyzers map[string]analyzer.Analyzer
|
||||
Refs map[string]analyzerRuleRef
|
||||
}
|
||||
|
||||
func collectAnalyzerRefs(root ast.Node, analyzers map[string]analyzer.Analyzer) map[string]analyzerRuleRef {
|
||||
visitor := &analyzerRefVisitor{
|
||||
Analyzers: analyzers,
|
||||
Refs: make(map[string]analyzerRuleRef),
|
||||
}
|
||||
ast.Walk(&root, visitor)
|
||||
return visitor.Refs
|
||||
}
|
||||
|
||||
func (v *analyzerRefVisitor) Visit(node *ast.Node) {
|
||||
switch n := (*node).(type) {
|
||||
case *ast.IdentifierNode:
|
||||
if _, ok := v.Analyzers[n.Value]; ok {
|
||||
v.add(n.Value, false)
|
||||
}
|
||||
case *ast.MemberNode:
|
||||
path := memberPath(n)
|
||||
if len(path) == 0 {
|
||||
return
|
||||
}
|
||||
name := path[0]
|
||||
if _, ok := v.Analyzers[name]; !ok {
|
||||
return
|
||||
}
|
||||
v.add(name, len(path) > 1 && isResponseSideAnalyzerPath(path[1]))
|
||||
}
|
||||
}
|
||||
|
||||
func (v *analyzerRefVisitor) add(name string, responseSide bool) {
|
||||
ref := v.Refs[name]
|
||||
ref.ResponseSide = ref.ResponseSide || responseSide
|
||||
v.Refs[name] = ref
|
||||
}
|
||||
|
||||
func isResponseSideAnalyzerPath(name string) bool {
|
||||
switch name {
|
||||
case "resp", "server", "answers", "response":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// idPatcher patches the AST during expr compilation, replacing certain values with
|
||||
// their internal representations for better runtime performance.
|
||||
type idPatcher struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -12,6 +13,13 @@ import (
|
||||
"github.com/expr-lang/expr/parser"
|
||||
)
|
||||
|
||||
type testAnalyzer struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (a testAnalyzer) Name() string { return a.name }
|
||||
func (a testAnalyzer) Limit() int { return 0 }
|
||||
|
||||
func TestExtractGeoSiteConditions(t *testing.T) {
|
||||
expression := `
|
||||
(geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) &&
|
||||
@@ -88,3 +96,93 @@ func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
|
||||
t.Fatalf("expected OR chain to be collapsed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) {
|
||||
rs, err := CompileExprRules([]ExprRule{
|
||||
{Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`},
|
||||
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("CompileExprRules error: %v", err)
|
||||
}
|
||||
exprRS := rs.(*exprRuleset)
|
||||
if len(exprRS.Ans) != 0 {
|
||||
t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans))
|
||||
}
|
||||
if exprRS.Rules[0].Native == nil {
|
||||
t.Fatalf("expected network-only rule to compile to native matcher")
|
||||
}
|
||||
got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443})
|
||||
if got.Action != ActionAllow {
|
||||
t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) {
|
||||
rs, err := CompileExprRules([]ExprRule{
|
||||
{Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`},
|
||||
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("CompileExprRules error: %v", err)
|
||||
}
|
||||
exprRS := rs.(*exprRuleset)
|
||||
if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" {
|
||||
t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNativeCIDRMatcher(t *testing.T) {
|
||||
funcMap, geoMatcher := buildFunctionMapForTest()
|
||||
n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher)
|
||||
if n == nil {
|
||||
t.Fatal("expected native matcher")
|
||||
}
|
||||
if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) {
|
||||
t.Fatal("expected native CIDR matcher to match")
|
||||
}
|
||||
if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) {
|
||||
t.Fatal("expected native CIDR matcher not to match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) {
|
||||
rs, err := CompileExprRules([]ExprRule{
|
||||
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
|
||||
{Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`},
|
||||
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("CompileExprRules error: %v", err)
|
||||
}
|
||||
|
||||
info := StreamInfo{
|
||||
Props: analyzer.CombinedPropMap{
|
||||
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||
},
|
||||
}
|
||||
if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
|
||||
t.Fatal("expected request-only rules to allow log finalization once request props exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) {
|
||||
rs, err := CompileExprRules([]ExprRule{
|
||||
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
|
||||
{Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`},
|
||||
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("CompileExprRules error: %v", err)
|
||||
}
|
||||
|
||||
info := StreamInfo{
|
||||
Props: analyzer.CombinedPropMap{
|
||||
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
|
||||
},
|
||||
}
|
||||
if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
|
||||
t.Fatal("expected response-side rule to keep inspection open")
|
||||
}
|
||||
}
|
||||
|
||||
func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) {
|
||||
m, g := buildFunctionMap(&BuiltinConfig{}, nil)
|
||||
return m, g
|
||||
}
|
||||
|
||||
@@ -85,6 +85,7 @@ func (i StreamInfo) DstString() string {
|
||||
type MatchResult struct {
|
||||
Action Action
|
||||
ModInstance modifier.Instance
|
||||
Logged bool
|
||||
}
|
||||
|
||||
type Ruleset interface {
|
||||
@@ -96,6 +97,10 @@ type Ruleset interface {
|
||||
Match(StreamInfo) MatchResult
|
||||
}
|
||||
|
||||
type LogFinalizer interface {
|
||||
CanFinalizeAfterLog(StreamInfo, []string) bool
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
MatchCalls uint64
|
||||
MatchErrors uint64
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/expr-lang/expr/ast"
|
||||
"github.com/expr-lang/expr/parser"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
|
||||
)
|
||||
|
||||
type nativeExpr interface {
|
||||
Match(StreamInfo) bool
|
||||
}
|
||||
|
||||
type nativeBoolFunc func(StreamInfo) bool
|
||||
|
||||
func (f nativeBoolFunc) Match(info StreamInfo) bool {
|
||||
return f(info)
|
||||
}
|
||||
|
||||
type nativeValueFunc func(StreamInfo) (any, bool)
|
||||
|
||||
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
|
||||
tree, err := parser.Parse(expression)
|
||||
if err != nil || tree == nil || tree.Node == nil {
|
||||
return nil
|
||||
}
|
||||
root := tree.Node
|
||||
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
|
||||
ast.Walk(&root, patcher)
|
||||
if patcher.Err != nil {
|
||||
return nil
|
||||
}
|
||||
return compileNativeBool(root)
|
||||
}
|
||||
|
||||
func compileNativeBool(node ast.Node) nativeExpr {
|
||||
switch n := node.(type) {
|
||||
case *ast.BinaryNode:
|
||||
switch n.Operator {
|
||||
case "&&", "and":
|
||||
left := compileNativeBool(n.Left)
|
||||
right := compileNativeBool(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return left.Match(info) && right.Match(info)
|
||||
})
|
||||
case "||", "or":
|
||||
left := compileNativeBool(n.Left)
|
||||
right := compileNativeBool(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return left.Match(info) || right.Match(info)
|
||||
})
|
||||
case "==", "!=", ">", ">=", "<", "<=":
|
||||
left := compileNativeValue(n.Left)
|
||||
right := compileNativeValue(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
op := n.Operator
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
lv, lok := left(info)
|
||||
rv, rok := right(info)
|
||||
if !lok || !rok {
|
||||
return false
|
||||
}
|
||||
result, ok := compareNativeValues(lv, rv, op)
|
||||
return ok && result
|
||||
})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
case *ast.UnaryNode:
|
||||
if n.Operator != "!" && n.Operator != "not" {
|
||||
return nil
|
||||
}
|
||||
child := compileNativeBool(n.Node)
|
||||
if child == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return !child.Match(info)
|
||||
})
|
||||
case *ast.CallNode:
|
||||
return compileNativeCall(n)
|
||||
case *ast.BoolNode:
|
||||
value := n.Value
|
||||
return nativeBoolFunc(func(StreamInfo) bool { return value })
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compileNativeCall(n *ast.CallNode) nativeExpr {
|
||||
id, ok := n.Callee.(*ast.IdentifierNode)
|
||||
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
|
||||
return nil
|
||||
}
|
||||
ipValue := compileNativeValue(n.Arguments[0])
|
||||
if ipValue == nil {
|
||||
return nil
|
||||
}
|
||||
var cidr *net.IPNet
|
||||
switch arg := n.Arguments[1].(type) {
|
||||
case *ast.ConstantNode:
|
||||
cidr, _ = arg.Value.(*net.IPNet)
|
||||
case *ast.StringNode:
|
||||
_, parsed, err := net.ParseCIDR(arg.Value)
|
||||
if err == nil {
|
||||
cidr = parsed
|
||||
}
|
||||
}
|
||||
if cidr == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
value, ok := ipValue(info)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case net.IP:
|
||||
return cidr.Contains(v)
|
||||
case string:
|
||||
ip := net.ParseIP(v)
|
||||
return ip != nil && cidr.Contains(ip)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func compileNativeValue(node ast.Node) nativeValueFunc {
|
||||
switch n := node.(type) {
|
||||
case *ast.StringNode:
|
||||
value := n.Value
|
||||
return func(StreamInfo) (any, bool) { return value, true }
|
||||
case *ast.IntegerNode:
|
||||
value := int64(n.Value)
|
||||
return func(StreamInfo) (any, bool) { return value, true }
|
||||
case *ast.IdentifierNode:
|
||||
switch strings.ToLower(n.Value) {
|
||||
case "proto":
|
||||
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
case *ast.MemberNode:
|
||||
return compileNativeMember(n)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
|
||||
path := memberPath(n)
|
||||
switch strings.Join(path, ".") {
|
||||
case "mac.src":
|
||||
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
|
||||
case "mac.dst":
|
||||
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
|
||||
case "ip.src":
|
||||
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
|
||||
case "ip.dst":
|
||||
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
|
||||
case "port.src":
|
||||
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
|
||||
case "port.dst":
|
||||
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func memberPath(node ast.Node) []string {
|
||||
switch n := node.(type) {
|
||||
case *ast.IdentifierNode:
|
||||
return []string{strings.ToLower(n.Value)}
|
||||
case *ast.MemberNode:
|
||||
base := memberPath(n.Node)
|
||||
prop, ok := n.Property.(*ast.StringNode)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return append(base, strings.ToLower(prop.Value))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compareNativeValues(left, right any, op string) (bool, bool) {
|
||||
if li, lok := nativeInt(left); lok {
|
||||
ri, rok := nativeInt(right)
|
||||
if !rok {
|
||||
return false, false
|
||||
}
|
||||
return compareNativeOrdered(li, ri, op), true
|
||||
}
|
||||
ls, lok := nativeString(left)
|
||||
if !lok {
|
||||
return false, false
|
||||
}
|
||||
rs, rok := nativeString(right)
|
||||
if !rok {
|
||||
return false, false
|
||||
}
|
||||
switch op {
|
||||
case "==":
|
||||
return ls == rs, true
|
||||
case "!=":
|
||||
return ls != rs, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func compareNativeOrdered(left, right int64, op string) bool {
|
||||
switch op {
|
||||
case "==":
|
||||
return left == right
|
||||
case "!=":
|
||||
return left != right
|
||||
case ">":
|
||||
return left > right
|
||||
case ">=":
|
||||
return left >= right
|
||||
case "<":
|
||||
return left < right
|
||||
case "<=":
|
||||
return left <= right
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func nativeInt(v any) (int64, bool) {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return int64(n), true
|
||||
case int64:
|
||||
return n, true
|
||||
case uint16:
|
||||
return int64(n), true
|
||||
case *ast.IntegerNode:
|
||||
return int64(n.Value), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func nativeString(v any) (string, bool) {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s, true
|
||||
case net.IP:
|
||||
if s == nil {
|
||||
return "", false
|
||||
}
|
||||
return s.String(), true
|
||||
case int64:
|
||||
return strconv.FormatInt(s, 10), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user