engine: more performance improvements

This commit is contained in:
2026-05-18 09:17:24 +05:30
parent 77dba0c4fa
commit 581041b1a7
12 changed files with 1005 additions and 103 deletions
+273
View File
@@ -0,0 +1,273 @@
package ruleset
import (
"net"
"strconv"
"strings"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/parser"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
)
type nativeExpr interface {
Match(StreamInfo) bool
}
type nativeBoolFunc func(StreamInfo) bool
func (f nativeBoolFunc) Match(info StreamInfo) bool {
return f(info)
}
type nativeValueFunc func(StreamInfo) (any, bool)
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
tree, err := parser.Parse(expression)
if err != nil || tree == nil || tree.Node == nil {
return nil
}
root := tree.Node
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
ast.Walk(&root, patcher)
if patcher.Err != nil {
return nil
}
return compileNativeBool(root)
}
func compileNativeBool(node ast.Node) nativeExpr {
switch n := node.(type) {
case *ast.BinaryNode:
switch n.Operator {
case "&&", "and":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) && right.Match(info)
})
case "||", "or":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) || right.Match(info)
})
case "==", "!=", ">", ">=", "<", "<=":
left := compileNativeValue(n.Left)
right := compileNativeValue(n.Right)
if left == nil || right == nil {
return nil
}
op := n.Operator
return nativeBoolFunc(func(info StreamInfo) bool {
lv, lok := left(info)
rv, rok := right(info)
if !lok || !rok {
return false
}
result, ok := compareNativeValues(lv, rv, op)
return ok && result
})
default:
return nil
}
case *ast.UnaryNode:
if n.Operator != "!" && n.Operator != "not" {
return nil
}
child := compileNativeBool(n.Node)
if child == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return !child.Match(info)
})
case *ast.CallNode:
return compileNativeCall(n)
case *ast.BoolNode:
value := n.Value
return nativeBoolFunc(func(StreamInfo) bool { return value })
default:
return nil
}
}
func compileNativeCall(n *ast.CallNode) nativeExpr {
id, ok := n.Callee.(*ast.IdentifierNode)
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
return nil
}
ipValue := compileNativeValue(n.Arguments[0])
if ipValue == nil {
return nil
}
var cidr *net.IPNet
switch arg := n.Arguments[1].(type) {
case *ast.ConstantNode:
cidr, _ = arg.Value.(*net.IPNet)
case *ast.StringNode:
_, parsed, err := net.ParseCIDR(arg.Value)
if err == nil {
cidr = parsed
}
}
if cidr == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
value, ok := ipValue(info)
if !ok {
return false
}
switch v := value.(type) {
case net.IP:
return cidr.Contains(v)
case string:
ip := net.ParseIP(v)
return ip != nil && cidr.Contains(ip)
default:
return false
}
})
}
func compileNativeValue(node ast.Node) nativeValueFunc {
switch n := node.(type) {
case *ast.StringNode:
value := n.Value
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IntegerNode:
value := int64(n.Value)
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IdentifierNode:
switch strings.ToLower(n.Value) {
case "proto":
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
default:
return nil
}
case *ast.MemberNode:
return compileNativeMember(n)
default:
return nil
}
}
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
path := memberPath(n)
switch strings.Join(path, ".") {
case "mac.src":
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
case "mac.dst":
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
case "ip.src":
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
case "ip.dst":
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
case "port.src":
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
case "port.dst":
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
default:
return nil
}
}
func memberPath(node ast.Node) []string {
switch n := node.(type) {
case *ast.IdentifierNode:
return []string{strings.ToLower(n.Value)}
case *ast.MemberNode:
base := memberPath(n.Node)
prop, ok := n.Property.(*ast.StringNode)
if !ok {
return nil
}
return append(base, strings.ToLower(prop.Value))
default:
return nil
}
}
func compareNativeValues(left, right any, op string) (bool, bool) {
if li, lok := nativeInt(left); lok {
ri, rok := nativeInt(right)
if !rok {
return false, false
}
return compareNativeOrdered(li, ri, op), true
}
ls, lok := nativeString(left)
if !lok {
return false, false
}
rs, rok := nativeString(right)
if !rok {
return false, false
}
switch op {
case "==":
return ls == rs, true
case "!=":
return ls != rs, true
default:
return false, false
}
}
func compareNativeOrdered(left, right int64, op string) bool {
switch op {
case "==":
return left == right
case "!=":
return left != right
case ">":
return left > right
case ">=":
return left >= right
case "<":
return left < right
case "<=":
return left <= right
default:
return false
}
}
func nativeInt(v any) (int64, bool) {
switch n := v.(type) {
case int:
return int64(n), true
case int64:
return n, true
case uint16:
return int64(n), true
case *ast.IntegerNode:
return int64(n.Value), true
default:
return 0, false
}
}
func nativeString(v any) (string, bool) {
switch s := v.(type) {
case string:
return s, true
case net.IP:
if s == nil {
return "", false
}
return s.String(), true
case int64:
return strconv.FormatInt(s, 10), true
default:
return "", false
}
}