package ruleset import ( "context" "fmt" "net" "os" "reflect" "strings" "sync" "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/parser" "github.com/expr-lang/expr/vm" "gopkg.in/yaml.v3" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/modifier" "git.difuse.io/Difuse/Mellaris/ruleset/builtins" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo" ) // ExprRule is the external representation of an expression rule. type ExprRule struct { Name string `yaml:"name"` Action string `yaml:"action"` Log bool `yaml:"log"` Modifier ModifierEntry `yaml:"modifier"` Expr string `yaml:"expr"` StartTime string `yaml:"start_time"` StopTime string `yaml:"stop_time"` Weekdays []string `yaml:"weekdays"` } type ModifierEntry struct { Name string `yaml:"name"` Args map[string]interface{} `yaml:"args"` } func ExprRulesFromYAML(file string) ([]ExprRule, error) { bs, err := os.ReadFile(file) if err != nil { return nil, err } var rules []ExprRule err = yaml.Unmarshal(bs, &rules) return rules, err } // compiledExprRule is the internal, compiled representation of an expression rule. type compiledExprRule struct { Name string Action *Action // fallthrough if nil Log bool ModInstance modifier.Instance Program *vm.Program GeoSiteConditions []string StartTimeSecs int // seconds since midnight, -1 if unset StopTimeSecs int // seconds since midnight, -1 if unset Weekdays []time.Weekday WeekdaysNegated bool } var _ Ruleset = (*exprRuleset)(nil) var ( envPool = sync.Pool{ New: func() any { return make(map[string]any, 16) }, } subMapPool = sync.Pool{ New: func() any { return make(map[string]any, 8) }, } ) type exprRuleset struct { Rules []compiledExprRule Ans []analyzer.Analyzer Logger Logger GeoMatcher *geo.GeoMatcher stats *statsCounters } func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { return r.Ans } func (r *exprRuleset) Match(info StreamInfo) MatchResult { start := time.Now() if r.stats != nil { r.stats.MatchCalls.Add(1) defer func() { r.stats.MatchLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) }() } env := envPool.Get().(map[string]any) clear(env) macMap, ipMap, portMap := populateExprEnv(env, info) releaseEnv := func() { clear(env) envPool.Put(env) putSubMap(macMap) putSubMap(ipMap) putSubMap(portMap) } now := time.Now() for _, rule := range r.Rules { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { continue } v, err := vm.Run(rule.Program, env) if err != nil { if r.stats != nil { r.stats.MatchErrors.Add(1) } r.Logger.MatchError(info, rule.Name, err) continue } if vBool, ok := v.(bool); ok && vBool { if rule.Log { logInfo := info if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil { logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions) } r.Logger.Log(logInfo, rule.Name) } if rule.Action != nil { releaseEnv() return MatchResult{ Action: *rule.Action, ModInstance: rule.ModInstance, } } } } releaseEnv() return MatchResult{ Action: ActionMaybe, } } func (r *exprRuleset) Stats() Stats { if r == nil || r.stats == nil { return Stats{} } return Stats{ MatchCalls: r.stats.MatchCalls.Load(), MatchErrors: r.stats.MatchErrors.Load(), MatchLatencyNanos: r.stats.MatchLatencyNanos.Load(), LookupCalls: r.stats.LookupCalls.Load(), LookupErrors: r.stats.LookupErrors.Load(), LookupLatencyNanos: r.stats.LookupLatencyNanos.Load(), } } // CompileExprRules compiles a list of expression rules into a ruleset. // It returns an error if any of the rules are invalid, or if any of the analyzers // used by the rules are unknown (not provided in the analyzer list). func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier.Modifier, config *BuiltinConfig) (Ruleset, error) { var compiledRules []compiledExprRule fullAnMap := analyzersToMap(ans) fullModMap := modifiersToMap(mods) depAnMap := make(map[string]analyzer.Analyzer) stats := &statsCounters{} funcMap, geoMatcher := buildFunctionMap(config, stats) // Compile all rules and build a map of analyzers that are used by the rules. for _, rule := range rules { if rule.Action == "" && !rule.Log { return nil, fmt.Errorf("rule %q must have at least one of action or log", rule.Name) } var action *Action if rule.Action != "" { a, ok := actionStringToAction(rule.Action) if !ok { return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action) } action = &a } visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: geoMatcher} program, err := expr.Compile(rule.Expr, func(c *conf.Config) { c.Strict = false c.Expect = reflect.Bool c.Visitors = append(c.Visitors, visitor, patcher) for name, f := range funcMap { c.Functions[name] = &builtin.Function{ Name: name, Func: f.Func, Types: f.Types, } } }, ) if err != nil { return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err) } if patcher.Err != nil { return nil, fmt.Errorf("rule %q failed to patch expression: %w", rule.Name, patcher.Err) } for name := range visitor.Identifiers { // Skip built-in analyzers & user-defined variables if isBuiltInAnalyzer(name) || visitor.Variables[name] { continue } if f, ok := funcMap[name]; ok { // Built-in function, initialize if necessary if f.InitFunc != nil { if err := f.InitFunc(); err != nil { return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err) } } } else if a, ok := fullAnMap[name]; ok { // Analyzer, add to dependency map depAnMap[name] = a } } startSecs := -1 if rule.StartTime != "" { startSecs, err = parseTimeOfDay(rule.StartTime) if err != nil { return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err) } } stopSecs := -1 if rule.StopTime != "" { stopSecs, err = parseTimeOfDay(rule.StopTime) if err != nil { return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err) } } weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays) if err != nil { return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err) } cr := compiledExprRule{ Name: rule.Name, Action: action, Log: rule.Log, Program: program, GeoSiteConditions: extractGeoSiteConditions(rule.Expr), StartTimeSecs: startSecs, StopTimeSecs: stopSecs, Weekdays: weekdays, WeekdaysNegated: weekdaysNegated, } if action != nil && *action == ActionModify { mod, ok := fullModMap[rule.Modifier.Name] if !ok { return nil, fmt.Errorf("rule %q uses unknown modifier %q", rule.Name, rule.Modifier.Name) } modInst, err := mod.New(rule.Modifier.Args) if err != nil { return nil, fmt.Errorf("rule %q failed to create modifier instance: %w", rule.Name, err) } cr.ModInstance = modInst } compiledRules = append(compiledRules, cr) } return &exprRuleset{ Rules: compiledRules, Ans: ans, Logger: config.Logger, GeoMatcher: geoMatcher, stats: stats, }, nil } func populateExprEnv(m map[string]any, info StreamInfo) (macMap, ipMap, portMap map[string]any) { macMap = getSubMap() ipMap = getSubMap() portMap = getSubMap() macMap["src"] = info.SrcMAC.String() macMap["dst"] = info.DstMAC.String() ipMap["src"] = info.SrcIP.String() ipMap["dst"] = info.DstIP.String() portMap["src"] = info.SrcPort portMap["dst"] = info.DstPort m["id"] = info.ID m["proto"] = info.Protocol.String() m["mac"] = macMap m["ip"] = ipMap m["port"] = portMap for anName, anProps := range info.Props { if len(anProps) != 0 { m[anName] = anProps } } return macMap, ipMap, portMap } func getSubMap() map[string]any { m := subMapPool.Get().(map[string]any) clear(m) return m } func putSubMap(m map[string]any) { if m == nil { return } clear(m) subMapPool.Put(m) } func isBuiltInAnalyzer(name string) bool { switch name { case "id", "proto", "mac", "ip", "port": return true default: return false } } func actionStringToAction(action string) (Action, bool) { switch strings.ToLower(action) { case "allow": return ActionAllow, true case "block": return ActionBlock, true case "drop": return ActionDrop, true case "modify": return ActionModify, true default: return ActionMaybe, false } } // analyzersToMap converts a list of analyzers to a map of name -> analyzer. // This is for easier lookup when compiling rules. func analyzersToMap(ans []analyzer.Analyzer) map[string]analyzer.Analyzer { anMap := make(map[string]analyzer.Analyzer) for _, a := range ans { anMap[a.Name()] = a } return anMap } // modifiersToMap converts a list of modifiers to a map of name -> modifier. // This is for easier lookup when compiling rules. func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier { modMap := make(map[string]modifier.Modifier) for _, m := range mods { modMap[m.Name()] = m } return modMap } // idVisitor is a visitor that collects all identifiers in an expression. // This is for determining which analyzers are used by the expression. type idVisitor struct { Variables map[string]bool Identifiers map[string]bool } func (v *idVisitor) Visit(node *ast.Node) { if varNode, ok := (*node).(*ast.VariableDeclaratorNode); ok { v.Variables[varNode.Name] = true } else if idNode, ok := (*node).(*ast.IdentifierNode); ok { v.Identifiers[idNode.Value] = true } } // idPatcher patches the AST during expr compilation, replacing certain values with // their internal representations for better runtime performance. type idPatcher struct { FuncMap map[string]*Function GeoMatcher *geo.GeoMatcher Err error } func (p *idPatcher) Visit(node *ast.Node) { if p.tryPatchGeoSiteORChain(node) { return } switch (*node).(type) { case *ast.CallNode: callNode := (*node).(*ast.CallNode) if callNode.Callee == nil { // Ignore invalid call nodes return } if f, ok := p.FuncMap[callNode.Callee.String()]; ok { if f.PatchFunc != nil { if err := f.PatchFunc(&callNode.Arguments); err != nil { p.Err = err return } } } } } func (p *idPatcher) tryPatchGeoSiteORChain(node *ast.Node) bool { if p == nil || p.GeoMatcher == nil { return false } terms, ok := collectGeoSiteORChain(*node) if !ok || len(terms) < 2 { return false } hostExpr := strings.TrimSpace(terms[0].hostExpr) if hostExpr == "" { return false } conditions := make([]string, 0, len(terms)) for _, term := range terms { if strings.TrimSpace(term.hostExpr) != hostExpr { return false } conditions = append(conditions, term.condition) } normalized := normalizeUniqueLowerStrings(conditions) if len(normalized) < 2 { return false } hostNode, err := parser.Parse(hostExpr) if err != nil || hostNode == nil || hostNode.Node == nil { return false } call := &ast.CallNode{ Callee: &ast.IdentifierNode{Value: "geosite_set"}, Arguments: []ast.Node{ hostNode.Node, &ast.ConstantNode{Value: &geo.SiteConditionSet{Conditions: normalized}}, }, } ast.Patch(node, call) return true } type geositeTerm struct { hostExpr string condition string } func collectGeoSiteORChain(node ast.Node) ([]geositeTerm, bool) { switch n := node.(type) { case *ast.BinaryNode: if n.Operator != "or" && n.Operator != "||" { return nil, false } left, ok := collectGeoSiteORChain(n.Left) if !ok { return nil, false } right, ok := collectGeoSiteORChain(n.Right) if !ok { return nil, false } out := make([]geositeTerm, 0, len(left)+len(right)) out = append(out, left...) out = append(out, right...) return out, true case *ast.CallNode: idNode, ok := n.Callee.(*ast.IdentifierNode) if !ok || len(n.Arguments) < 2 { return nil, false } name := strings.ToLower(idNode.Value) if name == "geosite" { condNode, ok := n.Arguments[1].(*ast.StringNode) if !ok { return nil, false } return []geositeTerm{{ hostExpr: n.Arguments[0].String(), condition: condNode.Value, }}, true } if name != "geosite_set" { return nil, false } setNode, ok := n.Arguments[1].(*ast.ConstantNode) if !ok || setNode.Value == nil { return nil, false } set, ok := setNode.Value.(*geo.SiteConditionSet) if !ok || set == nil { return nil, false } if len(set.Conditions) == 0 { return nil, false } out := make([]geositeTerm, 0, len(set.Conditions)) hostExpr := n.Arguments[0].String() for _, condition := range set.Conditions { out = append(out, geositeTerm{hostExpr: hostExpr, condition: condition}) } return out, true default: return nil, false } } type Function struct { InitFunc func() error PatchFunc func(args *[]ast.Node) error Func func(params ...any) (any, error) Types []reflect.Type } func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*Function, *geo.GeoMatcher) { geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) return map[string]*Function{ "geoip": { InitFunc: geoMatcher.LoadGeoIP, PatchFunc: nil, Func: func(params ...any) (any, error) { a, ok1 := params[0].(string) b, ok2 := params[1].(string) if !ok1 || !ok2 { return false, nil } return geoMatcher.MatchGeoIp(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, }, "geosite": { InitFunc: geoMatcher.LoadGeoSite, PatchFunc: nil, Func: func(params ...any) (any, error) { a, ok1 := params[0].(string) b, ok2 := params[1].(string) if !ok1 || !ok2 { return false, nil } return geoMatcher.MatchGeoSite(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, }, "geosite_set": { InitFunc: geoMatcher.LoadGeoSite, PatchFunc: nil, Func: func(params ...any) (any, error) { a, ok1 := params[0].(string) b, ok2 := params[1].(*geo.SiteConditionSet) if !ok1 || !ok2 { return false, nil } return geoMatcher.MatchGeoSiteSet(a, b), nil }, Types: []reflect.Type{ reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)), }, }, "cidr": { InitFunc: nil, PatchFunc: func(args *[]ast.Node) error { cidrStringNode, ok := (*args)[1].(*ast.StringNode) if !ok { return fmt.Errorf("cidr: invalid argument type") } cidr, err := builtins.CompileCIDR(cidrStringNode.Value) if err != nil { return err } (*args)[1] = &ast.ConstantNode{Value: cidr} return nil }, Func: func(params ...any) (any, error) { a, ok1 := params[0].(string) b, ok2 := params[1].(*net.IPNet) if !ok1 || !ok2 { return false, nil } return builtins.MatchCIDR(a, b), nil }, Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)}, }, "lookup": { InitFunc: nil, PatchFunc: func(args *[]ast.Node) error { var serverStr *ast.StringNode if len(*args) > 1 { var ok bool serverStr, ok = (*args)[1].(*ast.StringNode) if !ok { return fmt.Errorf("lookup: invalid argument type") } } r := &net.Resolver{ Dial: func(ctx context.Context, network, address string) (net.Conn, error) { if serverStr != nil { address = serverStr.Value } return config.ProtectedDialContext(ctx, network, address) }, } if len(*args) > 1 { (*args)[1] = &ast.ConstantNode{Value: r} } else { *args = append(*args, &ast.ConstantNode{Value: r}) } return nil }, Func: func(params ...any) (any, error) { start := time.Now() if stats != nil { stats.LookupCalls.Add(1) defer func() { stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) }() } a, ok1 := params[0].(string) b, ok2 := params[1].(*net.Resolver) if !ok1 || !ok2 { return nil, nil } ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) defer cancel() out, err := b.LookupHost(ctx, a) if err != nil && stats != nil { stats.LookupErrors.Add(1) } return out, err }, Types: []reflect.Type{ reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), }, }, }, geoMatcher } func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool { if startSecs >= 0 || stopSecs >= 0 { currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second() if startSecs >= 0 && stopSecs >= 0 { if startSecs <= stopSecs { if currentSecs < startSecs || currentSecs > stopSecs { return false } } else { if currentSecs < startSecs && currentSecs > stopSecs { return false } } } else if startSecs >= 0 { if currentSecs < startSecs { return false } } else if currentSecs > stopSecs { return false } } if len(weekdays) > 0 { current := now.Weekday() found := false for _, d := range weekdays { if current == d { found = true break } } if negated == found { return false } } return true } func parseTimeOfDay(s string) (int, error) { t, err := time.Parse("15:04:05", s) if err != nil { return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s) } return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil } func parseWeekdays(days []string) ([]time.Weekday, bool, error) { if len(days) == 0 { return nil, false, nil } negated := false parsed := make([]time.Weekday, 0, len(days)) for i, d := range days { d = strings.TrimSpace(d) if i == 0 && strings.HasPrefix(d, "!") { negated = true d = strings.TrimSpace(strings.TrimPrefix(d, "!")) } var wd time.Weekday switch strings.ToLower(d) { case "sun", "sunday": wd = time.Sunday case "mon", "monday": wd = time.Monday case "tue", "tues", "tuesday": wd = time.Tuesday case "wed", "wednesday": wd = time.Wednesday case "thu", "thur", "thurs", "thursday": wd = time.Thursday case "fri", "friday": wd = time.Friday case "sat", "saturday": wd = time.Saturday default: return nil, false, fmt.Errorf("invalid weekday %q", d) } parsed = append(parsed, wd) } return parsed, negated, nil } const rulesetLogMetaKey = "_ruleset" func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo { hosts := extractGeoSiteHostCandidates(info) if len(hosts) == 0 { return info } matchedGeoSites := matchGeoSiteConditions(hosts, conditions, gm.MatchGeoSite) if len(matchedGeoSites) == 0 { return info } clonedProps := cloneCombinedPropMap(info.Props) clonedProps[rulesetLogMetaKey] = analyzer.PropMap{ "geosite": matchedGeoSites, "hosts": hosts, } info.Props = clonedProps return info } func extractGeoSiteHostCandidates(info StreamInfo) []string { out := make([]string, 0, 4) seen := make(map[string]struct{}, 4) add := func(raw string) { host := normalizeHost(raw) if host == "" { return } if _, ok := seen[host]; ok { return } seen[host] = struct{}{} out = append(out, host) } if sni, ok := info.Props.Get("tls", "req.sni").(string); ok { add(sni) } if sni, ok := info.Props.Get("quic", "req.sni").(string); ok { add(sni) } if host, ok := info.Props.Get("http", "req.headers.host").(string); ok { add(host) } if addr, ok := info.Props.Get("socks", "req.addr").(string); ok { add(addr) } qs := info.Props.Get("dns", "questions") switch v := qs.(type) { case []analyzer.PropMap: for _, q := range v { if name, ok := q["name"].(string); ok { add(name) } } case []interface{}: for _, item := range v { switch q := item.(type) { case analyzer.PropMap: if name, ok := q["name"].(string); ok { add(name) } case map[string]interface{}: if name, ok := q["name"].(string); ok { add(name) } } } } return out } func normalizeHost(raw string) string { s := strings.TrimSpace(strings.ToLower(raw)) if s == "" { return "" } // Handle bracketed host:port first, then unbracketed host:port. if strings.HasPrefix(s, "[") { if host, _, err := net.SplitHostPort(s); err == nil { s = host } } else if strings.Count(s, ":") == 1 { if host, _, err := net.SplitHostPort(s); err == nil { s = host } } s = strings.TrimPrefix(s, "[") s = strings.TrimSuffix(s, "]") s = strings.TrimSuffix(s, ".") if s == "" || net.ParseIP(s) != nil { return "" } return s } func matchGeoSiteConditions(hosts, conditions []string, matchFn func(site, condition string) bool) []string { out := make([]string, 0, len(conditions)) seen := make(map[string]struct{}, len(conditions)) for _, cond := range conditions { c := strings.TrimSpace(strings.ToLower(cond)) if c == "" { continue } if _, ok := seen[c]; ok { continue } for _, host := range hosts { if matchFn(host, c) { seen[c] = struct{}{} out = append(out, c) break } } } return out } func cloneCombinedPropMap(in analyzer.CombinedPropMap) analyzer.CombinedPropMap { if in == nil { return analyzer.CombinedPropMap{} } out := make(analyzer.CombinedPropMap, len(in)+1) for k, v := range in { out[k] = v } return out } func extractGeoSiteConditions(expression string) []string { tree, err := parser.Parse(expression) if err != nil || tree == nil || tree.Node == nil { return nil } root := tree.Node v := &geositeCallVisitor{ conditions: make([]string, 0, 4), } ast.Walk(&root, v) return normalizeUniqueLowerStrings(v.conditions) } type geositeCallVisitor struct { conditions []string } func (v *geositeCallVisitor) Visit(node *ast.Node) { callNode, ok := (*node).(*ast.CallNode) if !ok || callNode.Callee == nil || len(callNode.Arguments) < 2 { return } idNode, ok := callNode.Callee.(*ast.IdentifierNode) if !ok || strings.ToLower(idNode.Value) != "geosite" { return } stringNode, ok := callNode.Arguments[1].(*ast.StringNode) if !ok { return } v.conditions = append(v.conditions, stringNode.Value) } func normalizeUniqueLowerStrings(in []string) []string { out := make([]string, 0, len(in)) seen := make(map[string]struct{}, len(in)) for _, v := range in { s := strings.TrimSpace(strings.ToLower(v)) if s == "" { continue } if _, ok := seen[s]; ok { continue } seen[s] = struct{}{} out = append(out, s) } return out }