package ruleset import ( "net" "reflect" "strings" "testing" "git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/parser" ) type testAnalyzer struct { name string } func (a testAnalyzer) Name() string { return a.name } func (a testAnalyzer) Limit() int { return 0 } func TestExtractGeoSiteConditions(t *testing.T) { expression := ` (geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) && geosite(http.req.headers.host, "google@ads") ` got := extractGeoSiteConditions(expression) want := []string{"openai", "google@ads"} if !reflect.DeepEqual(got, want) { t.Fatalf("extractGeoSiteConditions() = %v, want %v", got, want) } } func TestExtractGeoSiteHostCandidates(t *testing.T) { info := StreamInfo{ Props: analyzer.CombinedPropMap{ "quic": analyzer.PropMap{ "req": analyzer.PropMap{ "sni": "ChatGPT.com", }, }, "http": analyzer.PropMap{ "req": analyzer.PropMap{ "headers": analyzer.PropMap{ "host": "api.openai.com:443", }, }, }, "dns": analyzer.PropMap{ "questions": []analyzer.PropMap{ {"name": "chatgpt.com."}, {"name": "8.8.8.8"}, }, }, }, } got := extractGeoSiteHostCandidates(info) want := []string{"chatgpt.com", "api.openai.com"} if !reflect.DeepEqual(got, want) { t.Fatalf("extractGeoSiteHostCandidates() = %v, want %v", got, want) } } func TestMatchGeoSiteConditions(t *testing.T) { hosts := []string{"chatgpt.com", "api.openai.com"} conditions := []string{" openai ", "google", "OPENAI"} got := matchGeoSiteConditions(hosts, conditions, func(site, condition string) bool { if condition != "openai" { return false } return site == "chatgpt.com" || site == "api.openai.com" }) want := []string{"openai"} if !reflect.DeepEqual(got, want) { t.Fatalf("matchGeoSiteConditions() = %v, want %v", got, want) } } func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) { tree, err := parser.Parse(`geosite(tls.req.sni, "google") || geosite(tls.req.sni, "youtube") || geosite(tls.req.sni, "openai")`) if err != nil { t.Fatalf("parse expression: %v", err) } root := tree.Node patcher := &idPatcher{GeoMatcher: geo.NewGeoMatcher("", "")} ast.Walk(&root, patcher) if patcher.Err != nil { t.Fatalf("patch error: %v", patcher.Err) } got := root.String() if !strings.Contains(got, "geosite_set(") { t.Fatalf("expected geosite_set rewrite, got %q", got) } if strings.Contains(got, "||") || strings.Contains(got, " or ") { t.Fatalf("expected OR chain to be collapsed, got %q", got) } } func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) { rs, err := CompileExprRules([]ExprRule{ {Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`}, }, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{}) if err != nil { t.Fatalf("CompileExprRules error: %v", err) } exprRS := rs.(*exprRuleset) if len(exprRS.Ans) != 0 { t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans)) } if exprRS.Rules[0].Native == nil { t.Fatalf("expected network-only rule to compile to native matcher") } got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443}) if got.Action != ActionAllow { t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow) } } func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) { rs, err := CompileExprRules([]ExprRule{ {Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`}, }, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{}) if err != nil { t.Fatalf("CompileExprRules error: %v", err) } exprRS := rs.(*exprRuleset) if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" { t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans) } } func TestNativeCIDRMatcher(t *testing.T) { funcMap, geoMatcher := buildFunctionMapForTest() n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher) if n == nil { t.Fatal("expected native matcher") } if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) { t.Fatal("expected native CIDR matcher to match") } if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) { t.Fatal("expected native CIDR matcher not to match") } } func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) { rs, err := CompileExprRules([]ExprRule{ {Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`}, {Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`}, }, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{}) if err != nil { t.Fatalf("CompileExprRules error: %v", err) } info := StreamInfo{ Props: analyzer.CombinedPropMap{ "tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}}, }, } if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) { t.Fatal("expected request-only rules to allow log finalization once request props exist") } } func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) { rs, err := CompileExprRules([]ExprRule{ {Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`}, {Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`}, }, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{}) if err != nil { t.Fatalf("CompileExprRules error: %v", err) } info := StreamInfo{ Props: analyzer.CombinedPropMap{ "tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}}, }, } if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) { t.Fatal("expected response-side rule to keep inspection open") } } func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) { m, g := buildFunctionMap(&BuiltinConfig{}, nil) return m, g }