diff --git a/ruleset/expr.go b/ruleset/expr.go index 11682f0..276d7aa 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -14,6 +14,7 @@ import ( "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/parser" "github.com/expr-lang/expr/vm" "gopkg.in/yaml.v3" @@ -49,19 +50,21 @@ func ExprRulesFromYAML(file string) ([]ExprRule, error) { // compiledExprRule is the internal, compiled representation of an expression rule. type compiledExprRule struct { - Name string - Action *Action // fallthrough if nil - Log bool - ModInstance modifier.Instance - Program *vm.Program + Name string + Action *Action // fallthrough if nil + Log bool + ModInstance modifier.Instance + Program *vm.Program + GeoSiteConditions []string } var _ Ruleset = (*exprRuleset)(nil) type exprRuleset struct { - Rules []compiledExprRule - Ans []analyzer.Analyzer - Logger Logger + Rules []compiledExprRule + Ans []analyzer.Analyzer + Logger Logger + GeoMatcher *geo.GeoMatcher } func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { @@ -79,7 +82,11 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult { } if vBool, ok := v.(bool); ok && vBool { if rule.Log { - r.Logger.Log(info, rule.Name) + logInfo := info + if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil { + logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions) + } + r.Logger.Log(logInfo, rule.Name) } if rule.Action != nil { return MatchResult{ @@ -103,7 +110,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier fullAnMap := analyzersToMap(ans) fullModMap := modifiersToMap(mods) depAnMap := make(map[string]analyzer.Analyzer) - funcMap := buildFunctionMap(config) + funcMap, geoMatcher := buildFunctionMap(config) // Compile all rules and build a map of analyzers that are used by the rules. for _, rule := range rules { if rule.Action == "" && !rule.Log { @@ -157,10 +164,11 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier } } cr := compiledExprRule{ - Name: rule.Name, - Action: action, - Log: rule.Log, - Program: program, + Name: rule.Name, + Action: action, + Log: rule.Log, + Program: program, + GeoSiteConditions: extractGeoSiteConditions(rule.Expr), } if action != nil && *action == ActionModify { mod, ok := fullModMap[rule.Modifier.Name] @@ -181,9 +189,10 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier depAns = append(depAns, a) } return &exprRuleset{ - Rules: compiledRules, - Ans: depAns, - Logger: config.Logger, + Rules: compiledRules, + Ans: depAns, + Logger: config.Logger, + GeoMatcher: geoMatcher, }, nil } @@ -305,7 +314,7 @@ type Function struct { Types []reflect.Type } -func buildFunctionMap(config *BuiltinConfig) map[string]*Function { +func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatcher) { geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) return map[string]*Function{ "geoip": { @@ -379,5 +388,183 @@ func buildFunctionMap(config *BuiltinConfig) map[string]*Function { reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), }, }, - } + }, geoMatcher +} + +const rulesetLogMetaKey = "_ruleset" + +func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo { + hosts := extractGeoSiteHostCandidates(info) + if len(hosts) == 0 { + return info + } + matchedGeoSites := matchGeoSiteConditions(hosts, conditions, gm.MatchGeoSite) + if len(matchedGeoSites) == 0 { + return info + } + clonedProps := cloneCombinedPropMap(info.Props) + clonedProps[rulesetLogMetaKey] = analyzer.PropMap{ + "geosite": matchedGeoSites, + "hosts": hosts, + } + info.Props = clonedProps + return info +} + +func extractGeoSiteHostCandidates(info StreamInfo) []string { + out := make([]string, 0, 4) + seen := make(map[string]struct{}, 4) + add := func(raw string) { + host := normalizeHost(raw) + if host == "" { + return + } + if _, ok := seen[host]; ok { + return + } + seen[host] = struct{}{} + out = append(out, host) + } + if sni, ok := info.Props.Get("tls", "req.sni").(string); ok { + add(sni) + } + if sni, ok := info.Props.Get("quic", "req.sni").(string); ok { + add(sni) + } + if host, ok := info.Props.Get("http", "req.headers.host").(string); ok { + add(host) + } + if addr, ok := info.Props.Get("socks", "req.addr").(string); ok { + add(addr) + } + qs := info.Props.Get("dns", "questions") + switch v := qs.(type) { + case []analyzer.PropMap: + for _, q := range v { + if name, ok := q["name"].(string); ok { + add(name) + } + } + case []interface{}: + for _, item := range v { + switch q := item.(type) { + case analyzer.PropMap: + if name, ok := q["name"].(string); ok { + add(name) + } + case map[string]interface{}: + if name, ok := q["name"].(string); ok { + add(name) + } + } + } + } + return out +} + +func normalizeHost(raw string) string { + s := strings.TrimSpace(strings.ToLower(raw)) + if s == "" { + return "" + } + // Handle bracketed host:port first, then unbracketed host:port. + if strings.HasPrefix(s, "[") { + if host, _, err := net.SplitHostPort(s); err == nil { + s = host + } + } else if strings.Count(s, ":") == 1 { + if host, _, err := net.SplitHostPort(s); err == nil { + s = host + } + } + s = strings.TrimPrefix(s, "[") + s = strings.TrimSuffix(s, "]") + s = strings.TrimSuffix(s, ".") + if s == "" || net.ParseIP(s) != nil { + return "" + } + return s +} + +func matchGeoSiteConditions(hosts, conditions []string, matchFn func(site, condition string) bool) []string { + out := make([]string, 0, len(conditions)) + seen := make(map[string]struct{}, len(conditions)) + for _, cond := range conditions { + c := strings.TrimSpace(strings.ToLower(cond)) + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + for _, host := range hosts { + if matchFn(host, c) { + seen[c] = struct{}{} + out = append(out, c) + break + } + } + } + return out +} + +func cloneCombinedPropMap(in analyzer.CombinedPropMap) analyzer.CombinedPropMap { + if in == nil { + return analyzer.CombinedPropMap{} + } + out := make(analyzer.CombinedPropMap, len(in)+1) + for k, v := range in { + out[k] = v + } + return out +} + +func extractGeoSiteConditions(expression string) []string { + tree, err := parser.Parse(expression) + if err != nil || tree == nil || tree.Node == nil { + return nil + } + root := tree.Node + v := &geositeCallVisitor{ + conditions: make([]string, 0, 4), + } + ast.Walk(&root, v) + return normalizeUniqueLowerStrings(v.conditions) +} + +type geositeCallVisitor struct { + conditions []string +} + +func (v *geositeCallVisitor) Visit(node *ast.Node) { + callNode, ok := (*node).(*ast.CallNode) + if !ok || callNode.Callee == nil || len(callNode.Arguments) < 2 { + return + } + idNode, ok := callNode.Callee.(*ast.IdentifierNode) + if !ok || strings.ToLower(idNode.Value) != "geosite" { + return + } + stringNode, ok := callNode.Arguments[1].(*ast.StringNode) + if !ok { + return + } + v.conditions = append(v.conditions, stringNode.Value) +} + +func normalizeUniqueLowerStrings(in []string) []string { + out := make([]string, 0, len(in)) + seen := make(map[string]struct{}, len(in)) + for _, v := range in { + s := strings.TrimSpace(strings.ToLower(v)) + if s == "" { + continue + } + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out } diff --git a/ruleset/expr_test.go b/ruleset/expr_test.go new file mode 100644 index 0000000..23345f0 --- /dev/null +++ b/ruleset/expr_test.go @@ -0,0 +1,65 @@ +package ruleset + +import ( + "reflect" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +func TestExtractGeoSiteConditions(t *testing.T) { + expression := ` + (geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) && + geosite(http.req.headers.host, "google@ads") + ` + got := extractGeoSiteConditions(expression) + want := []string{"openai", "google@ads"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("extractGeoSiteConditions() = %v, want %v", got, want) + } +} + +func TestExtractGeoSiteHostCandidates(t *testing.T) { + info := StreamInfo{ + Props: analyzer.CombinedPropMap{ + "quic": analyzer.PropMap{ + "req": analyzer.PropMap{ + "sni": "ChatGPT.com", + }, + }, + "http": analyzer.PropMap{ + "req": analyzer.PropMap{ + "headers": analyzer.PropMap{ + "host": "api.openai.com:443", + }, + }, + }, + "dns": analyzer.PropMap{ + "questions": []analyzer.PropMap{ + {"name": "chatgpt.com."}, + {"name": "8.8.8.8"}, + }, + }, + }, + } + got := extractGeoSiteHostCandidates(info) + want := []string{"chatgpt.com", "api.openai.com"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("extractGeoSiteHostCandidates() = %v, want %v", got, want) + } +} + +func TestMatchGeoSiteConditions(t *testing.T) { + hosts := []string{"chatgpt.com", "api.openai.com"} + conditions := []string{" openai ", "google", "OPENAI"} + got := matchGeoSiteConditions(hosts, conditions, func(site, condition string) bool { + if condition != "openai" { + return false + } + return site == "chatgpt.com" || site == "api.openai.com" + }) + want := []string{"openai"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want) + } +}