engine: more performance improvements
This commit is contained in:
@@ -0,0 +1,273 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/expr-lang/expr/ast"
|
||||
"github.com/expr-lang/expr/parser"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
|
||||
)
|
||||
|
||||
type nativeExpr interface {
|
||||
Match(StreamInfo) bool
|
||||
}
|
||||
|
||||
type nativeBoolFunc func(StreamInfo) bool
|
||||
|
||||
func (f nativeBoolFunc) Match(info StreamInfo) bool {
|
||||
return f(info)
|
||||
}
|
||||
|
||||
type nativeValueFunc func(StreamInfo) (any, bool)
|
||||
|
||||
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
|
||||
tree, err := parser.Parse(expression)
|
||||
if err != nil || tree == nil || tree.Node == nil {
|
||||
return nil
|
||||
}
|
||||
root := tree.Node
|
||||
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
|
||||
ast.Walk(&root, patcher)
|
||||
if patcher.Err != nil {
|
||||
return nil
|
||||
}
|
||||
return compileNativeBool(root)
|
||||
}
|
||||
|
||||
func compileNativeBool(node ast.Node) nativeExpr {
|
||||
switch n := node.(type) {
|
||||
case *ast.BinaryNode:
|
||||
switch n.Operator {
|
||||
case "&&", "and":
|
||||
left := compileNativeBool(n.Left)
|
||||
right := compileNativeBool(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return left.Match(info) && right.Match(info)
|
||||
})
|
||||
case "||", "or":
|
||||
left := compileNativeBool(n.Left)
|
||||
right := compileNativeBool(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return left.Match(info) || right.Match(info)
|
||||
})
|
||||
case "==", "!=", ">", ">=", "<", "<=":
|
||||
left := compileNativeValue(n.Left)
|
||||
right := compileNativeValue(n.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
op := n.Operator
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
lv, lok := left(info)
|
||||
rv, rok := right(info)
|
||||
if !lok || !rok {
|
||||
return false
|
||||
}
|
||||
result, ok := compareNativeValues(lv, rv, op)
|
||||
return ok && result
|
||||
})
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
case *ast.UnaryNode:
|
||||
if n.Operator != "!" && n.Operator != "not" {
|
||||
return nil
|
||||
}
|
||||
child := compileNativeBool(n.Node)
|
||||
if child == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
return !child.Match(info)
|
||||
})
|
||||
case *ast.CallNode:
|
||||
return compileNativeCall(n)
|
||||
case *ast.BoolNode:
|
||||
value := n.Value
|
||||
return nativeBoolFunc(func(StreamInfo) bool { return value })
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compileNativeCall(n *ast.CallNode) nativeExpr {
|
||||
id, ok := n.Callee.(*ast.IdentifierNode)
|
||||
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
|
||||
return nil
|
||||
}
|
||||
ipValue := compileNativeValue(n.Arguments[0])
|
||||
if ipValue == nil {
|
||||
return nil
|
||||
}
|
||||
var cidr *net.IPNet
|
||||
switch arg := n.Arguments[1].(type) {
|
||||
case *ast.ConstantNode:
|
||||
cidr, _ = arg.Value.(*net.IPNet)
|
||||
case *ast.StringNode:
|
||||
_, parsed, err := net.ParseCIDR(arg.Value)
|
||||
if err == nil {
|
||||
cidr = parsed
|
||||
}
|
||||
}
|
||||
if cidr == nil {
|
||||
return nil
|
||||
}
|
||||
return nativeBoolFunc(func(info StreamInfo) bool {
|
||||
value, ok := ipValue(info)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case net.IP:
|
||||
return cidr.Contains(v)
|
||||
case string:
|
||||
ip := net.ParseIP(v)
|
||||
return ip != nil && cidr.Contains(ip)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func compileNativeValue(node ast.Node) nativeValueFunc {
|
||||
switch n := node.(type) {
|
||||
case *ast.StringNode:
|
||||
value := n.Value
|
||||
return func(StreamInfo) (any, bool) { return value, true }
|
||||
case *ast.IntegerNode:
|
||||
value := int64(n.Value)
|
||||
return func(StreamInfo) (any, bool) { return value, true }
|
||||
case *ast.IdentifierNode:
|
||||
switch strings.ToLower(n.Value) {
|
||||
case "proto":
|
||||
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
case *ast.MemberNode:
|
||||
return compileNativeMember(n)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
|
||||
path := memberPath(n)
|
||||
switch strings.Join(path, ".") {
|
||||
case "mac.src":
|
||||
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
|
||||
case "mac.dst":
|
||||
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
|
||||
case "ip.src":
|
||||
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
|
||||
case "ip.dst":
|
||||
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
|
||||
case "port.src":
|
||||
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
|
||||
case "port.dst":
|
||||
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func memberPath(node ast.Node) []string {
|
||||
switch n := node.(type) {
|
||||
case *ast.IdentifierNode:
|
||||
return []string{strings.ToLower(n.Value)}
|
||||
case *ast.MemberNode:
|
||||
base := memberPath(n.Node)
|
||||
prop, ok := n.Property.(*ast.StringNode)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return append(base, strings.ToLower(prop.Value))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func compareNativeValues(left, right any, op string) (bool, bool) {
|
||||
if li, lok := nativeInt(left); lok {
|
||||
ri, rok := nativeInt(right)
|
||||
if !rok {
|
||||
return false, false
|
||||
}
|
||||
return compareNativeOrdered(li, ri, op), true
|
||||
}
|
||||
ls, lok := nativeString(left)
|
||||
if !lok {
|
||||
return false, false
|
||||
}
|
||||
rs, rok := nativeString(right)
|
||||
if !rok {
|
||||
return false, false
|
||||
}
|
||||
switch op {
|
||||
case "==":
|
||||
return ls == rs, true
|
||||
case "!=":
|
||||
return ls != rs, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
|
||||
func compareNativeOrdered(left, right int64, op string) bool {
|
||||
switch op {
|
||||
case "==":
|
||||
return left == right
|
||||
case "!=":
|
||||
return left != right
|
||||
case ">":
|
||||
return left > right
|
||||
case ">=":
|
||||
return left >= right
|
||||
case "<":
|
||||
return left < right
|
||||
case "<=":
|
||||
return left <= right
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func nativeInt(v any) (int64, bool) {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return int64(n), true
|
||||
case int64:
|
||||
return n, true
|
||||
case uint16:
|
||||
return int64(n), true
|
||||
case *ast.IntegerNode:
|
||||
return int64(n.Value), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func nativeString(v any) (string, bool) {
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return s, true
|
||||
case net.IP:
|
||||
if s == nil {
|
||||
return "", false
|
||||
}
|
||||
return s.String(), true
|
||||
case int64:
|
||||
return strconv.FormatInt(s, 10), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user