274 lines
6.0 KiB
Go
274 lines
6.0 KiB
Go
package ruleset
|
|
|
|
import (
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/expr-lang/expr/ast"
|
|
"github.com/expr-lang/expr/parser"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
|
|
)
|
|
|
|
type nativeExpr interface {
|
|
Match(StreamInfo) bool
|
|
}
|
|
|
|
type nativeBoolFunc func(StreamInfo) bool
|
|
|
|
func (f nativeBoolFunc) Match(info StreamInfo) bool {
|
|
return f(info)
|
|
}
|
|
|
|
type nativeValueFunc func(StreamInfo) (any, bool)
|
|
|
|
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
|
|
tree, err := parser.Parse(expression)
|
|
if err != nil || tree == nil || tree.Node == nil {
|
|
return nil
|
|
}
|
|
root := tree.Node
|
|
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
|
|
ast.Walk(&root, patcher)
|
|
if patcher.Err != nil {
|
|
return nil
|
|
}
|
|
return compileNativeBool(root)
|
|
}
|
|
|
|
func compileNativeBool(node ast.Node) nativeExpr {
|
|
switch n := node.(type) {
|
|
case *ast.BinaryNode:
|
|
switch n.Operator {
|
|
case "&&", "and":
|
|
left := compileNativeBool(n.Left)
|
|
right := compileNativeBool(n.Right)
|
|
if left == nil || right == nil {
|
|
return nil
|
|
}
|
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
|
return left.Match(info) && right.Match(info)
|
|
})
|
|
case "||", "or":
|
|
left := compileNativeBool(n.Left)
|
|
right := compileNativeBool(n.Right)
|
|
if left == nil || right == nil {
|
|
return nil
|
|
}
|
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
|
return left.Match(info) || right.Match(info)
|
|
})
|
|
case "==", "!=", ">", ">=", "<", "<=":
|
|
left := compileNativeValue(n.Left)
|
|
right := compileNativeValue(n.Right)
|
|
if left == nil || right == nil {
|
|
return nil
|
|
}
|
|
op := n.Operator
|
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
|
lv, lok := left(info)
|
|
rv, rok := right(info)
|
|
if !lok || !rok {
|
|
return false
|
|
}
|
|
result, ok := compareNativeValues(lv, rv, op)
|
|
return ok && result
|
|
})
|
|
default:
|
|
return nil
|
|
}
|
|
case *ast.UnaryNode:
|
|
if n.Operator != "!" && n.Operator != "not" {
|
|
return nil
|
|
}
|
|
child := compileNativeBool(n.Node)
|
|
if child == nil {
|
|
return nil
|
|
}
|
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
|
return !child.Match(info)
|
|
})
|
|
case *ast.CallNode:
|
|
return compileNativeCall(n)
|
|
case *ast.BoolNode:
|
|
value := n.Value
|
|
return nativeBoolFunc(func(StreamInfo) bool { return value })
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func compileNativeCall(n *ast.CallNode) nativeExpr {
|
|
id, ok := n.Callee.(*ast.IdentifierNode)
|
|
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
|
|
return nil
|
|
}
|
|
ipValue := compileNativeValue(n.Arguments[0])
|
|
if ipValue == nil {
|
|
return nil
|
|
}
|
|
var cidr *net.IPNet
|
|
switch arg := n.Arguments[1].(type) {
|
|
case *ast.ConstantNode:
|
|
cidr, _ = arg.Value.(*net.IPNet)
|
|
case *ast.StringNode:
|
|
_, parsed, err := net.ParseCIDR(arg.Value)
|
|
if err == nil {
|
|
cidr = parsed
|
|
}
|
|
}
|
|
if cidr == nil {
|
|
return nil
|
|
}
|
|
return nativeBoolFunc(func(info StreamInfo) bool {
|
|
value, ok := ipValue(info)
|
|
if !ok {
|
|
return false
|
|
}
|
|
switch v := value.(type) {
|
|
case net.IP:
|
|
return cidr.Contains(v)
|
|
case string:
|
|
ip := net.ParseIP(v)
|
|
return ip != nil && cidr.Contains(ip)
|
|
default:
|
|
return false
|
|
}
|
|
})
|
|
}
|
|
|
|
func compileNativeValue(node ast.Node) nativeValueFunc {
|
|
switch n := node.(type) {
|
|
case *ast.StringNode:
|
|
value := n.Value
|
|
return func(StreamInfo) (any, bool) { return value, true }
|
|
case *ast.IntegerNode:
|
|
value := int64(n.Value)
|
|
return func(StreamInfo) (any, bool) { return value, true }
|
|
case *ast.IdentifierNode:
|
|
switch strings.ToLower(n.Value) {
|
|
case "proto":
|
|
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
|
|
default:
|
|
return nil
|
|
}
|
|
case *ast.MemberNode:
|
|
return compileNativeMember(n)
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
|
|
path := memberPath(n)
|
|
switch strings.Join(path, ".") {
|
|
case "mac.src":
|
|
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
|
|
case "mac.dst":
|
|
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
|
|
case "ip.src":
|
|
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
|
|
case "ip.dst":
|
|
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
|
|
case "port.src":
|
|
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
|
|
case "port.dst":
|
|
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func memberPath(node ast.Node) []string {
|
|
switch n := node.(type) {
|
|
case *ast.IdentifierNode:
|
|
return []string{strings.ToLower(n.Value)}
|
|
case *ast.MemberNode:
|
|
base := memberPath(n.Node)
|
|
prop, ok := n.Property.(*ast.StringNode)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return append(base, strings.ToLower(prop.Value))
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func compareNativeValues(left, right any, op string) (bool, bool) {
|
|
if li, lok := nativeInt(left); lok {
|
|
ri, rok := nativeInt(right)
|
|
if !rok {
|
|
return false, false
|
|
}
|
|
return compareNativeOrdered(li, ri, op), true
|
|
}
|
|
ls, lok := nativeString(left)
|
|
if !lok {
|
|
return false, false
|
|
}
|
|
rs, rok := nativeString(right)
|
|
if !rok {
|
|
return false, false
|
|
}
|
|
switch op {
|
|
case "==":
|
|
return ls == rs, true
|
|
case "!=":
|
|
return ls != rs, true
|
|
default:
|
|
return false, false
|
|
}
|
|
}
|
|
|
|
// compareNativeOrdered applies a comparison operator to two integers.
// Unknown operators evaluate to false.
func compareNativeOrdered(left, right int64, op string) bool {
	// Reduce the pair to its ordering (-1, 0, +1) once, then test that sign.
	order := 0
	switch {
	case left < right:
		order = -1
	case left > right:
		order = 1
	}

	switch op {
	case "==":
		return order == 0
	case "!=":
		return order != 0
	case ">":
		return order > 0
	case ">=":
		return order >= 0
	case "<":
		return order < 0
	case "<=":
		return order <= 0
	}
	return false
}
|
|
|
|
func nativeInt(v any) (int64, bool) {
|
|
switch n := v.(type) {
|
|
case int:
|
|
return int64(n), true
|
|
case int64:
|
|
return n, true
|
|
case uint16:
|
|
return int64(n), true
|
|
case *ast.IntegerNode:
|
|
return int64(n.Value), true
|
|
default:
|
|
return 0, false
|
|
}
|
|
}
|
|
|
|
// nativeString renders a native operand as a string for equality checks.
//
// It accepts a string as-is, a non-nil net.IP via its String method, and an
// int64 in base 10. A nil net.IP and every other type report false.
func nativeString(v any) (string, bool) {
	switch val := v.(type) {
	case int64:
		return strconv.FormatInt(val, 10), true
	case net.IP:
		if val != nil {
			return val.String(), true
		}
	case string:
		return val, true
	}
	return "", false
}
|