package ruleset import ( "net" "strconv" "strings" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/parser" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo" ) type nativeExpr interface { Match(StreamInfo) bool } type nativeBoolFunc func(StreamInfo) bool func (f nativeBoolFunc) Match(info StreamInfo) bool { return f(info) } type nativeValueFunc func(StreamInfo) (any, bool) func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr { tree, err := parser.Parse(expression) if err != nil || tree == nil || tree.Node == nil { return nil } root := tree.Node patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm} ast.Walk(&root, patcher) if patcher.Err != nil { return nil } return compileNativeBool(root) } func compileNativeBool(node ast.Node) nativeExpr { switch n := node.(type) { case *ast.BinaryNode: switch n.Operator { case "&&", "and": left := compileNativeBool(n.Left) right := compileNativeBool(n.Right) if left == nil || right == nil { return nil } return nativeBoolFunc(func(info StreamInfo) bool { return left.Match(info) && right.Match(info) }) case "||", "or": left := compileNativeBool(n.Left) right := compileNativeBool(n.Right) if left == nil || right == nil { return nil } return nativeBoolFunc(func(info StreamInfo) bool { return left.Match(info) || right.Match(info) }) case "==", "!=", ">", ">=", "<", "<=": left := compileNativeValue(n.Left) right := compileNativeValue(n.Right) if left == nil || right == nil { return nil } op := n.Operator return nativeBoolFunc(func(info StreamInfo) bool { lv, lok := left(info) rv, rok := right(info) if !lok || !rok { return false } result, ok := compareNativeValues(lv, rv, op) return ok && result }) default: return nil } case *ast.UnaryNode: if n.Operator != "!" && n.Operator != "not" { return nil } child := compileNativeBool(n.Node) if child == nil { return nil } return nativeBoolFunc(func(info StreamInfo) bool { return !child.Match(info) }) case *ast.CallNode: return compileNativeCall(n) case *ast.BoolNode: value := n.Value return nativeBoolFunc(func(StreamInfo) bool { return value }) default: return nil } } func compileNativeCall(n *ast.CallNode) nativeExpr { id, ok := n.Callee.(*ast.IdentifierNode) if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 { return nil } ipValue := compileNativeValue(n.Arguments[0]) if ipValue == nil { return nil } var cidr *net.IPNet switch arg := n.Arguments[1].(type) { case *ast.ConstantNode: cidr, _ = arg.Value.(*net.IPNet) case *ast.StringNode: _, parsed, err := net.ParseCIDR(arg.Value) if err == nil { cidr = parsed } } if cidr == nil { return nil } return nativeBoolFunc(func(info StreamInfo) bool { value, ok := ipValue(info) if !ok { return false } switch v := value.(type) { case net.IP: return cidr.Contains(v) case string: ip := net.ParseIP(v) return ip != nil && cidr.Contains(ip) default: return false } }) } func compileNativeValue(node ast.Node) nativeValueFunc { switch n := node.(type) { case *ast.StringNode: value := n.Value return func(StreamInfo) (any, bool) { return value, true } case *ast.IntegerNode: value := int64(n.Value) return func(StreamInfo) (any, bool) { return value, true } case *ast.IdentifierNode: switch strings.ToLower(n.Value) { case "proto": return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true } default: return nil } case *ast.MemberNode: return compileNativeMember(n) default: return nil } } func compileNativeMember(n *ast.MemberNode) nativeValueFunc { path := memberPath(n) switch strings.Join(path, ".") { case "mac.src": return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true } case "mac.dst": return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true } case "ip.src": return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil } case "ip.dst": return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil } case "port.src": return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true } case "port.dst": return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true } default: return nil } } func memberPath(node ast.Node) []string { switch n := node.(type) { case *ast.IdentifierNode: return []string{strings.ToLower(n.Value)} case *ast.MemberNode: base := memberPath(n.Node) prop, ok := n.Property.(*ast.StringNode) if !ok { return nil } return append(base, strings.ToLower(prop.Value)) default: return nil } } func compareNativeValues(left, right any, op string) (bool, bool) { if li, lok := nativeInt(left); lok { ri, rok := nativeInt(right) if !rok { return false, false } return compareNativeOrdered(li, ri, op), true } ls, lok := nativeString(left) if !lok { return false, false } rs, rok := nativeString(right) if !rok { return false, false } switch op { case "==": return ls == rs, true case "!=": return ls != rs, true default: return false, false } } func compareNativeOrdered(left, right int64, op string) bool { switch op { case "==": return left == right case "!=": return left != right case ">": return left > right case ">=": return left >= right case "<": return left < right case "<=": return left <= right default: return false } } func nativeInt(v any) (int64, bool) { switch n := v.(type) { case int: return int64(n), true case int64: return n, true case uint16: return int64(n), true case *ast.IntegerNode: return int64(n.Value), true default: return 0, false } } func nativeString(v any) (string, bool) { switch s := v.(type) { case string: return s, true case net.IP: if s == nil { return "", false } return s.String(), true case int64: return strconv.FormatInt(s, 10), true default: return "", false } }