engine: more performance improvements

This commit is contained in:
2026-05-18 09:17:24 +05:30
parent 77dba0c4fa
commit 581041b1a7
12 changed files with 1005 additions and 103 deletions
+69 -21
View File
@@ -112,23 +112,35 @@ type geositeDomain struct {
}
type geositeMatcher struct {
Domains []geositeDomain
Domains []geositeDomain // legacy slow path for tests and manual construction
Plain []geositeDomain
Regex []geositeDomain
Root map[string]geositeDomain
Full map[string]geositeDomain
// Attributes are matched using "and" logic - if you have multiple attributes here,
// a domain must have all of those attributes to be considered a match.
Attrs []string
}
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
// Match attributes first
if len(m.Attrs) > 0 {
if len(domain.Attrs) == 0 {
func (m *geositeMatcher) attrsMatch(domain geositeDomain) bool {
if len(m.Attrs) == 0 {
return true
}
if len(domain.Attrs) == 0 {
return false
}
for _, attr := range m.Attrs {
if !domain.Attrs[attr] {
return false
}
for _, attr := range m.Attrs {
if !domain.Attrs[attr] {
return false
}
}
}
return true
}
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
// Match attributes first
if !m.attrsMatch(domain) {
return false
}
switch domain.Type {
@@ -152,7 +164,35 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
}
func (m *geositeMatcher) Match(host HostInfo) bool {
for _, domain := range m.Domains {
if host.Name == "" {
return false
}
if domain, ok := m.Full[host.Name]; ok && m.attrsMatch(domain) {
return true
}
for name := host.Name; name != ""; {
if domain, ok := m.Root[name]; ok && m.attrsMatch(domain) {
return true
}
idx := strings.IndexByte(name, '.')
if idx < 0 {
break
}
name = name[idx+1:]
}
for _, domain := range m.Plain {
if m.matchDomain(domain, host) {
return true
}
}
if len(m.Plain) == 0 && len(m.Regex) == 0 && len(m.Root) == 0 && len(m.Full) == 0 {
for _, domain := range m.Domains {
if m.matchDomain(domain, host) {
return true
}
}
}
for _, domain := range m.Regex {
if m.matchDomain(domain, host) {
return true
}
@@ -161,45 +201,53 @@ func (m *geositeMatcher) Match(host HostInfo) bool {
}
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
domains := make([]geositeDomain, len(list.Domain))
for i, domain := range list.Domain {
matcher := &geositeMatcher{
Root: make(map[string]geositeDomain),
Full: make(map[string]geositeDomain),
Attrs: attrs,
}
for _, domain := range list.Domain {
var compiled geositeDomain
switch domain.Type {
case v2geo.Domain_Plain:
domains[i] = geositeDomain{
compiled = geositeDomain{
Type: geositeDomainPlain,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
matcher.Plain = append(matcher.Plain, compiled)
case v2geo.Domain_Regex:
regex, err := regexp.Compile(domain.Value)
if err != nil {
return nil, err
}
domains[i] = geositeDomain{
compiled = geositeDomain{
Type: geositeDomainRegex,
Value: domain.Value,
Regex: regex,
Attrs: domainAttributeToMap(domain.Attribute),
}
matcher.Regex = append(matcher.Regex, compiled)
case v2geo.Domain_Full:
domains[i] = geositeDomain{
compiled = geositeDomain{
Type: geositeDomainFull,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
matcher.Full[domain.Value] = compiled
case v2geo.Domain_RootDomain:
domains[i] = geositeDomain{
compiled = geositeDomain{
Type: geositeDomainRoot,
Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute),
}
matcher.Root[domain.Value] = compiled
default:
return nil, errors.New("unsupported domain type")
}
matcher.Domains = append(matcher.Domains, compiled)
}
return &geositeMatcher{
Domains: domains,
Attrs: attrs,
}, nil
return matcher, nil
}
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
+126 -11
View File
@@ -59,6 +59,8 @@ type compiledExprRule struct {
Log bool
ModInstance modifier.Instance
Program *vm.Program
Native nativeExpr
AnalyzerRefs map[string]analyzerRuleRef
GeoSiteConditions []string
StartTimeSecs int // seconds since midnight, -1 if unset
StopTimeSecs int // seconds since midnight, -1 if unset
@@ -67,6 +69,7 @@ type compiledExprRule struct {
}
var _ Ruleset = (*exprRuleset)(nil)
var _ LogFinalizer = (*exprRuleset)(nil)
var (
envPool = sync.Pool{
@@ -102,10 +105,12 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
}()
}
env := envPool.Get().(map[string]any)
clear(env)
macMap, ipMap, portMap := populateExprEnv(env, info)
var env map[string]any
var macMap, ipMap, portMap map[string]any
releaseEnv := func() {
if env == nil {
return
}
clear(env)
envPool.Put(env)
putSubMap(macMap)
@@ -113,31 +118,45 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
putSubMap(portMap)
}
now := time.Now()
logged := false
for _, rule := range r.Rules {
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
continue
}
v, err := vm.Run(rule.Program, env)
if err != nil {
if r.stats != nil {
r.stats.MatchErrors.Add(1)
matched := false
if rule.Native != nil {
matched = rule.Native.Match(info)
} else {
if env == nil {
env = envPool.Get().(map[string]any)
clear(env)
macMap, ipMap, portMap = populateExprEnv(env, info)
}
r.Logger.MatchError(info, rule.Name, err)
continue
v, err := vm.Run(rule.Program, env)
if err != nil {
if r.stats != nil {
r.stats.MatchErrors.Add(1)
}
r.Logger.MatchError(info, rule.Name, err)
continue
}
matched, _ = v.(bool)
}
if vBool, ok := v.(bool); ok && vBool {
if matched {
if rule.Log {
logInfo := info
if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil {
logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions)
}
r.Logger.Log(logInfo, rule.Name)
logged = true
}
if rule.Action != nil {
releaseEnv()
return MatchResult{
Action: *rule.Action,
ModInstance: rule.ModInstance,
Logged: logged,
}
}
}
@@ -145,9 +164,40 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
releaseEnv()
return MatchResult{
Action: ActionMaybe,
Logged: logged,
}
}
func (r *exprRuleset) CanFinalizeAfterLog(info StreamInfo, activeAnalyzers []string) bool {
active := make(map[string]bool, len(activeAnalyzers))
for _, name := range activeAnalyzers {
active[name] = true
}
for _, rule := range r.Rules {
if rule.Action == nil {
continue
}
if *rule.Action == ActionModify {
return false
}
if rule.StartTimeSecs != -1 || rule.StopTimeSecs != -1 || len(rule.Weekdays) != 0 {
return false
}
for name, ref := range rule.AnalyzerRefs {
if !active[name] {
continue
}
if ref.ResponseSide {
return false
}
if _, ok := info.Props[name]; !ok {
return false
}
}
}
return true
}
func (r *exprRuleset) Stats() Stats {
if r == nil || r.stats == nil {
return Stats{}
@@ -242,17 +292,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
if err != nil {
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
}
var analyzerRefs map[string]analyzerRuleRef
if refTree, err := parser.Parse(rule.Expr); err == nil && refTree != nil {
analyzerRefs = collectAnalyzerRefs(refTree.Node, fullAnMap)
}
cr := compiledExprRule{
Name: rule.Name,
Action: action,
Log: rule.Log,
Program: program,
AnalyzerRefs: analyzerRefs,
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
StartTimeSecs: startSecs,
StopTimeSecs: stopSecs,
Weekdays: weekdays,
WeekdaysNegated: weekdaysNegated,
}
cr.Native = compileNativeExpr(rule.Expr, funcMap, geoMatcher)
if action != nil && *action == ActionModify {
mod, ok := fullModMap[rule.Modifier.Name]
if !ok {
@@ -266,9 +322,16 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
}
compiledRules = append(compiledRules, cr)
}
depAns := make([]analyzer.Analyzer, 0, len(depAnMap))
for _, a := range ans {
if depAnMap[a.Name()] != nil {
depAns = append(depAns, a)
}
}
return &exprRuleset{
Rules: compiledRules,
Ans: ans,
Ans: depAns,
Logger: config.Logger,
GeoMatcher: geoMatcher,
stats: stats,
@@ -373,6 +436,58 @@ func (v *idVisitor) Visit(node *ast.Node) {
}
}
type analyzerRuleRef struct {
ResponseSide bool
}
type analyzerRefVisitor struct {
Analyzers map[string]analyzer.Analyzer
Refs map[string]analyzerRuleRef
}
func collectAnalyzerRefs(root ast.Node, analyzers map[string]analyzer.Analyzer) map[string]analyzerRuleRef {
visitor := &analyzerRefVisitor{
Analyzers: analyzers,
Refs: make(map[string]analyzerRuleRef),
}
ast.Walk(&root, visitor)
return visitor.Refs
}
func (v *analyzerRefVisitor) Visit(node *ast.Node) {
switch n := (*node).(type) {
case *ast.IdentifierNode:
if _, ok := v.Analyzers[n.Value]; ok {
v.add(n.Value, false)
}
case *ast.MemberNode:
path := memberPath(n)
if len(path) == 0 {
return
}
name := path[0]
if _, ok := v.Analyzers[name]; !ok {
return
}
v.add(name, len(path) > 1 && isResponseSideAnalyzerPath(path[1]))
}
}
func (v *analyzerRefVisitor) add(name string, responseSide bool) {
ref := v.Refs[name]
ref.ResponseSide = ref.ResponseSide || responseSide
v.Refs[name] = ref
}
func isResponseSideAnalyzerPath(name string) bool {
switch name {
case "resp", "server", "answers", "response":
return true
default:
return false
}
}
// idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance.
type idPatcher struct {
+98
View File
@@ -1,6 +1,7 @@
package ruleset
import (
"net"
"reflect"
"strings"
"testing"
@@ -12,6 +13,13 @@ import (
"github.com/expr-lang/expr/parser"
)
type testAnalyzer struct {
name string
}
func (a testAnalyzer) Name() string { return a.name }
func (a testAnalyzer) Limit() int { return 0 }
func TestExtractGeoSiteConditions(t *testing.T) {
expression := `
(geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) &&
@@ -88,3 +96,93 @@ func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
t.Fatalf("expected OR chain to be collapsed, got %q", got)
}
}
func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
exprRS := rs.(*exprRuleset)
if len(exprRS.Ans) != 0 {
t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans))
}
if exprRS.Rules[0].Native == nil {
t.Fatalf("expected network-only rule to compile to native matcher")
}
got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443})
if got.Action != ActionAllow {
t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow)
}
}
func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
exprRS := rs.(*exprRuleset)
if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" {
t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans)
}
}
func TestNativeCIDRMatcher(t *testing.T) {
funcMap, geoMatcher := buildFunctionMapForTest()
n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher)
if n == nil {
t.Fatal("expected native matcher")
}
if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) {
t.Fatal("expected native CIDR matcher to match")
}
if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) {
t.Fatal("expected native CIDR matcher not to match")
}
}
func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
{Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
info := StreamInfo{
Props: analyzer.CombinedPropMap{
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
},
}
if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
t.Fatal("expected request-only rules to allow log finalization once request props exist")
}
}
func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
{Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
info := StreamInfo{
Props: analyzer.CombinedPropMap{
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
},
}
if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
t.Fatal("expected response-side rule to keep inspection open")
}
}
func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) {
m, g := buildFunctionMap(&BuiltinConfig{}, nil)
return m, g
}
+5
View File
@@ -85,6 +85,7 @@ func (i StreamInfo) DstString() string {
type MatchResult struct {
Action Action
ModInstance modifier.Instance
Logged bool
}
type Ruleset interface {
@@ -96,6 +97,10 @@ type Ruleset interface {
Match(StreamInfo) MatchResult
}
type LogFinalizer interface {
CanFinalizeAfterLog(StreamInfo, []string) bool
}
type Stats struct {
MatchCalls uint64
MatchErrors uint64
+273
View File
@@ -0,0 +1,273 @@
package ruleset
import (
"net"
"strconv"
"strings"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/parser"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
)
type nativeExpr interface {
Match(StreamInfo) bool
}
type nativeBoolFunc func(StreamInfo) bool
func (f nativeBoolFunc) Match(info StreamInfo) bool {
return f(info)
}
type nativeValueFunc func(StreamInfo) (any, bool)
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
tree, err := parser.Parse(expression)
if err != nil || tree == nil || tree.Node == nil {
return nil
}
root := tree.Node
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
ast.Walk(&root, patcher)
if patcher.Err != nil {
return nil
}
return compileNativeBool(root)
}
func compileNativeBool(node ast.Node) nativeExpr {
switch n := node.(type) {
case *ast.BinaryNode:
switch n.Operator {
case "&&", "and":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) && right.Match(info)
})
case "||", "or":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) || right.Match(info)
})
case "==", "!=", ">", ">=", "<", "<=":
left := compileNativeValue(n.Left)
right := compileNativeValue(n.Right)
if left == nil || right == nil {
return nil
}
op := n.Operator
return nativeBoolFunc(func(info StreamInfo) bool {
lv, lok := left(info)
rv, rok := right(info)
if !lok || !rok {
return false
}
result, ok := compareNativeValues(lv, rv, op)
return ok && result
})
default:
return nil
}
case *ast.UnaryNode:
if n.Operator != "!" && n.Operator != "not" {
return nil
}
child := compileNativeBool(n.Node)
if child == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return !child.Match(info)
})
case *ast.CallNode:
return compileNativeCall(n)
case *ast.BoolNode:
value := n.Value
return nativeBoolFunc(func(StreamInfo) bool { return value })
default:
return nil
}
}
func compileNativeCall(n *ast.CallNode) nativeExpr {
id, ok := n.Callee.(*ast.IdentifierNode)
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
return nil
}
ipValue := compileNativeValue(n.Arguments[0])
if ipValue == nil {
return nil
}
var cidr *net.IPNet
switch arg := n.Arguments[1].(type) {
case *ast.ConstantNode:
cidr, _ = arg.Value.(*net.IPNet)
case *ast.StringNode:
_, parsed, err := net.ParseCIDR(arg.Value)
if err == nil {
cidr = parsed
}
}
if cidr == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
value, ok := ipValue(info)
if !ok {
return false
}
switch v := value.(type) {
case net.IP:
return cidr.Contains(v)
case string:
ip := net.ParseIP(v)
return ip != nil && cidr.Contains(ip)
default:
return false
}
})
}
func compileNativeValue(node ast.Node) nativeValueFunc {
switch n := node.(type) {
case *ast.StringNode:
value := n.Value
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IntegerNode:
value := int64(n.Value)
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IdentifierNode:
switch strings.ToLower(n.Value) {
case "proto":
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
default:
return nil
}
case *ast.MemberNode:
return compileNativeMember(n)
default:
return nil
}
}
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
path := memberPath(n)
switch strings.Join(path, ".") {
case "mac.src":
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
case "mac.dst":
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
case "ip.src":
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
case "ip.dst":
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
case "port.src":
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
case "port.dst":
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
default:
return nil
}
}
func memberPath(node ast.Node) []string {
switch n := node.(type) {
case *ast.IdentifierNode:
return []string{strings.ToLower(n.Value)}
case *ast.MemberNode:
base := memberPath(n.Node)
prop, ok := n.Property.(*ast.StringNode)
if !ok {
return nil
}
return append(base, strings.ToLower(prop.Value))
default:
return nil
}
}
func compareNativeValues(left, right any, op string) (bool, bool) {
if li, lok := nativeInt(left); lok {
ri, rok := nativeInt(right)
if !rok {
return false, false
}
return compareNativeOrdered(li, ri, op), true
}
ls, lok := nativeString(left)
if !lok {
return false, false
}
rs, rok := nativeString(right)
if !rok {
return false, false
}
switch op {
case "==":
return ls == rs, true
case "!=":
return ls != rs, true
default:
return false, false
}
}
func compareNativeOrdered(left, right int64, op string) bool {
switch op {
case "==":
return left == right
case "!=":
return left != right
case ">":
return left > right
case ">=":
return left >= right
case "<":
return left < right
case "<=":
return left <= right
default:
return false
}
}
func nativeInt(v any) (int64, bool) {
switch n := v.(type) {
case int:
return int64(n), true
case int64:
return n, true
case uint16:
return int64(n), true
case *ast.IntegerNode:
return int64(n.Value), true
default:
return 0, false
}
}
func nativeString(v any) (string, bool) {
switch s := v.(type) {
case string:
return s, true
case net.IP:
if s == nil {
return "", false
}
return s.String(), true
case int64:
return strconv.FormatInt(s, 10), true
default:
return "", false
}
}