test: improve coverage across package
This commit is contained in:
98
ruleset/builtins/cidr_test.go
Normal file
98
ruleset/builtins/cidr_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package builtins
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileCIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
wantStr string
|
||||
}{
|
||||
{"valid ipv4", "192.168.0.0/24", false, "192.168.0.0/24"},
|
||||
{"valid ipv6", "2001:db8::/32", false, "2001:db8::/32"},
|
||||
{"valid host ipv4", "10.0.0.1/32", false, "10.0.0.1/32"},
|
||||
{"valid host ipv6", "::1/128", false, "::1/128"},
|
||||
{"invalid no mask", "192.168.0.0", true, ""},
|
||||
{"invalid bad ip", "not-an-ip/24", true, ""},
|
||||
{"invalid empty", "", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CompileCIDR(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("CompileCIDR(%q) expected error, got nil", tt.input)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("CompileCIDR(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if got.String() != tt.wantStr {
|
||||
t.Errorf("CompileCIDR(%q) = %q, want %q", tt.input, got.String(), tt.wantStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR(t *testing.T) {
|
||||
cidr := mustCompileCIDR(t, "192.168.0.0/24")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
{"inside", "192.168.0.1", true},
|
||||
{"boundary low", "192.168.0.0", true},
|
||||
{"boundary high", "192.168.0.255", true},
|
||||
{"outside", "192.168.1.1", false},
|
||||
{"different network", "10.0.0.1", false},
|
||||
{"invalid ip", "not-an-ip", false},
|
||||
{"empty", "", false},
|
||||
{"ipv6 in ipv4", "::1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MatchCIDR(tt.ip, cidr)
|
||||
if got != tt.want {
|
||||
t.Errorf("MatchCIDR(%q, %q) = %v, want %v", tt.ip, cidr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR_IPv6(t *testing.T) {
|
||||
cidr := mustCompileCIDR(t, "2001:db8::/32")
|
||||
|
||||
inside := "2001:db8::1"
|
||||
if !MatchCIDR(inside, cidr) {
|
||||
t.Errorf("MatchCIDR(%q) should be true", inside)
|
||||
}
|
||||
|
||||
outside := "2001:db9::1"
|
||||
if MatchCIDR(outside, cidr) {
|
||||
t.Errorf("MatchCIDR(%q) should be false", outside)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR_NullResult(t *testing.T) {
|
||||
if MatchCIDR("10.0.0.1", &net.IPNet{}) {
|
||||
t.Error("MatchCIDR with empty IPNet should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func mustCompileCIDR(t *testing.T, cidr string) *net.IPNet {
|
||||
t.Helper()
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR %q: %v", cidr, err)
|
||||
}
|
||||
return ipNet
|
||||
}
|
||||
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||
)
|
||||
|
||||
type fakeGeoLoader struct {
|
||||
geoip map[string]*v2geo.GeoIP
|
||||
geosite map[string]*v2geo.GeoSite
|
||||
}
|
||||
|
||||
func (l *fakeGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
||||
return l.geoip, nil
|
||||
}
|
||||
|
||||
func (l *fakeGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
||||
return l.geosite, nil
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_Cached(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geoip: map[string]*v2geo.GeoIP{
|
||||
"us": {
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: ipv4(8, 8, 8, 0), Prefix: 24},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if !g.MatchGeoIp("8.8.8.8", "US") {
|
||||
t.Error("MatchGeoIp should match 8.8.8.8 in US range")
|
||||
}
|
||||
if g.MatchGeoIp("9.9.9.9", "US") {
|
||||
t.Error("MatchGeoIp should not match 9.9.9.9 in US range")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_EmptyCondition(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoIp("1.2.3.4", "") {
|
||||
t.Error("MatchGeoIp with empty condition should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_InvalidIP(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoIp("not-an-ip", "us") {
|
||||
t.Error("MatchGeoIp with invalid IP should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_MissingCountry(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geoip: map[string]*v2geo.GeoIP{},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if g.MatchGeoIp("8.8.8.8", "us") {
|
||||
t.Error("MatchGeoIp for missing country should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geosite: map[string]*v2geo.GeoSite{
|
||||
"openai": {
|
||||
Domain: []*v2geo.Domain{
|
||||
{Type: v2geo.Domain_Plain, Value: "openai"},
|
||||
{Type: v2geo.Domain_Full, Value: "chatgpt.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if !g.MatchGeoSite("api.openai.com", "openai") {
|
||||
t.Error("MatchGeoSite should match via plain domain")
|
||||
}
|
||||
if !g.MatchGeoSite("chatgpt.com", "openai") {
|
||||
t.Error("MatchGeoSite should match via full domain")
|
||||
}
|
||||
if g.MatchGeoSite("google.com", "openai") {
|
||||
t.Error("MatchGeoSite should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite_EmptyCondition(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoSite("test.com", "") {
|
||||
t.Error("MatchGeoSite with empty condition should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geosite: map[string]*v2geo.GeoSite{},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if g.MatchGeoSite("test.com", "nonexistent") {
|
||||
t.Error("MatchGeoSite for missing site should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func ipv4(a, b, c, d byte) []byte {
|
||||
return []byte{a, b, c, d}
|
||||
}
|
||||
BIN
ruleset/builtins/geo/geoip.dat
Normal file
BIN
ruleset/builtins/geo/geoip.dat
Normal file
Binary file not shown.
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||
)
|
||||
|
||||
func TestParseGeoSiteName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantBase string
|
||||
wantAttrs []string
|
||||
}{
|
||||
{"google", "google", nil},
|
||||
{"google@ads", "google", []string{"ads"}},
|
||||
{"google@ads@news", "google", []string{"ads", "news"}},
|
||||
{" google ", "google", nil},
|
||||
{" google @ ads ", "google", []string{"ads"}},
|
||||
{"openai@ ads @ news ", "openai", []string{"ads", "news"}},
|
||||
{"@onlyattrs", "", []string{"onlyattrs"}},
|
||||
{"", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
base, attrs := parseGeoSiteName(tt.input)
|
||||
if base != tt.wantBase {
|
||||
t.Errorf("parseGeoSiteName(%q) base = %q, want %q", tt.input, base, tt.wantBase)
|
||||
}
|
||||
if len(attrs) != len(tt.wantAttrs) {
|
||||
t.Fatalf("parseGeoSiteName(%q) attrs len = %d, want %d", tt.input, len(attrs), len(tt.wantAttrs))
|
||||
}
|
||||
for i, attr := range attrs {
|
||||
if attr != tt.wantAttrs[i] {
|
||||
t.Errorf("parseGeoSiteName(%q) attrs[%d] = %q, want %q", tt.input, i, attr, tt.wantAttrs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostInfo_String(t *testing.T) {
|
||||
h := HostInfo{
|
||||
Name: "example.com",
|
||||
IPv4: net.ParseIP("1.2.3.4"),
|
||||
IPv6: net.ParseIP("::1"),
|
||||
}
|
||||
want := "example.com|1.2.3.4|::1"
|
||||
if got := h.String(); got != want {
|
||||
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostInfo_String_Partial(t *testing.T) {
|
||||
h := HostInfo{
|
||||
Name: "test.com",
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
}
|
||||
want := "test.com|10.0.0.1|<nil>"
|
||||
if got := h.String(); got != want {
|
||||
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match(t *testing.T) {
|
||||
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
_, n4_2, _ := net.ParseCIDR("192.168.0.0/16")
|
||||
m := &geoipMatcher{
|
||||
N4: []*net.IPNet{n4, n4_2},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{"ipv4 match", HostInfo{IPv4: net.ParseIP("10.1.2.3")}, true},
|
||||
{"ipv4 no match", HostInfo{IPv4: net.ParseIP("172.16.0.1")}, false},
|
||||
{"ipv4 match second net", HostInfo{IPv4: net.ParseIP("192.168.1.1")}, true},
|
||||
{"no ip", HostInfo{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match_Inverse(t *testing.T) {
|
||||
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
m := &geoipMatcher{
|
||||
N4: []*net.IPNet{n4},
|
||||
Inverse: true,
|
||||
}
|
||||
|
||||
if m.Match(HostInfo{IPv4: net.ParseIP("10.1.2.3")}) {
|
||||
t.Error("Inverse: inside range should return false")
|
||||
}
|
||||
if !m.Match(HostInfo{IPv4: net.ParseIP("172.16.0.1")}) {
|
||||
t.Error("Inverse: outside range should return true")
|
||||
}
|
||||
if !m.Match(HostInfo{}) {
|
||||
t.Error("Inverse: no IP should return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match_IPv6(t *testing.T) {
|
||||
_, n6, _ := net.ParseCIDR("2001:db8::/32")
|
||||
m := &geoipMatcher{
|
||||
N6: []*net.IPNet{n6},
|
||||
}
|
||||
|
||||
if !m.Match(HostInfo{IPv6: net.ParseIP("2001:db8::1")}) {
|
||||
t.Error("IPv6 match failed")
|
||||
}
|
||||
if m.Match(HostInfo{IPv6: net.ParseIP("2001:db9::1")}) {
|
||||
t.Error("IPv6 should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Plain(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "openai",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "api.openai.com"}) {
|
||||
t.Error("plain domain should match via substring")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||
t.Error("plain domain should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Full(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainFull,
|
||||
Value: "example.com",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||
t.Error("full domain should match exact")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||
t.Error("full domain should not match subdomain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Root(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainRoot,
|
||||
Value: "example.com",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||
t.Error("root domain should match exact")
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||
t.Error("root domain should match subdomain")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "www.example.com.au"}) {
|
||||
t.Error("root domain should not match unrelated suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Attrs(t *testing.T) {
|
||||
m := &geositeMatcher{Attrs: []string{"ads"}}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{"ads": true},
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should match when domain has required attr")
|
||||
}
|
||||
|
||||
dNoAttrs := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{},
|
||||
}
|
||||
if m.matchDomain(dNoAttrs, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should not match when domain lacks required attr")
|
||||
}
|
||||
|
||||
dOtherAttrs := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{"news": true},
|
||||
}
|
||||
if m.matchDomain(dOtherAttrs, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should not match when domain has wrong attr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_Match(t *testing.T) {
|
||||
m := &geositeMatcher{
|
||||
Domains: []geositeDomain{
|
||||
{Type: geositeDomainFull, Value: "exact.com"},
|
||||
{Type: geositeDomainPlain, Value: "partial"},
|
||||
},
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "exact.com"}) {
|
||||
t.Error("should match full domain")
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "www.partial.net"}) {
|
||||
t.Error("should match partial domain")
|
||||
}
|
||||
if m.Match(HostInfo{Name: "other.net"}) {
|
||||
t.Error("should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainAttributeToMap(t *testing.T) {
|
||||
attrs := []*v2geo.Domain_Attribute{
|
||||
{Key: "ads"},
|
||||
{Key: "news"},
|
||||
}
|
||||
got := domainAttributeToMap(attrs)
|
||||
if len(got) != 2 || !got["ads"] || !got["news"] {
|
||||
t.Errorf("domainAttributeToMap = %v, want {ads:true, news:true}", got)
|
||||
}
|
||||
|
||||
got2 := domainAttributeToMap(nil)
|
||||
if len(got2) != 0 {
|
||||
t.Errorf("domainAttributeToMap(nil) = %v, want empty map", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: net.IPv4(10, 0, 0, 0).To4(), Prefix: 8},
|
||||
{Ip: net.IPv4(192, 168, 0, 0).To4(), Prefix: 16},
|
||||
},
|
||||
InverseMatch: false,
|
||||
}
|
||||
m, err := newGeoIPMatcher(list)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||
}
|
||||
if len(m.N4) != 2 {
|
||||
t.Errorf("expected 2 IPv4 nets, got %d", len(m.N4))
|
||||
}
|
||||
if m.Inverse {
|
||||
t.Error("Inverse should be false")
|
||||
}
|
||||
// Verify sorted order: 10.0.0.0/8 < 192.168.0.0/16
|
||||
if m.N4[0].IP.String() != "10.0.0.0" {
|
||||
t.Errorf("N4[0] = %s, want 10.0.0.0", m.N4[0].IP)
|
||||
}
|
||||
if m.N4[1].IP.String() != "192.168.0.0" {
|
||||
t.Errorf("N4[1] = %s, want 192.168.0.0", m.N4[1].IP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher_IPv6(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: net.ParseIP("2001:db8::"), Prefix: 32},
|
||||
},
|
||||
}
|
||||
m, err := newGeoIPMatcher(list)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||
}
|
||||
if len(m.N6) != 1 {
|
||||
t.Errorf("expected 1 IPv6 net, got %d", len(m.N6))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher_InvalidIPLength(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: []byte{1, 2, 3}, Prefix: 24},
|
||||
},
|
||||
}
|
||||
_, err := newGeoIPMatcher(list)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid IP length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeositeMatcher(t *testing.T) {
|
||||
list := &v2geo.GeoSite{
|
||||
Domain: []*v2geo.Domain{
|
||||
{Type: v2geo.Domain_Plain, Value: "google"},
|
||||
{Type: v2geo.Domain_Full, Value: "exact.com"},
|
||||
},
|
||||
}
|
||||
m, err := newGeositeMatcher(list, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||
}
|
||||
if len(m.Domains) != 2 {
|
||||
t.Errorf("expected 2 domains, got %d", len(m.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeositeMatcher_WithAttrs(t *testing.T) {
|
||||
list := &v2geo.GeoSite{
|
||||
Domain: []*v2geo.Domain{
|
||||
{
|
||||
Type: v2geo.Domain_RootDomain,
|
||||
Value: "google.com",
|
||||
Attribute: []*v2geo.Domain_Attribute{
|
||||
{Key: "ads"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m, err := newGeositeMatcher(list, []string{"ads"})
|
||||
if err != nil {
|
||||
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "www.google.com"}) {
|
||||
t.Error("should match with root domain and attr")
|
||||
}
|
||||
}
|
||||
125
ruleset/expr.go
125
ruleset/expr.go
@@ -26,11 +26,14 @@ import (
|
||||
|
||||
// ExprRule is the external representation of an expression rule.
|
||||
type ExprRule struct {
|
||||
Name string `yaml:"name"`
|
||||
Action string `yaml:"action"`
|
||||
Log bool `yaml:"log"`
|
||||
Modifier ModifierEntry `yaml:"modifier"`
|
||||
Expr string `yaml:"expr"`
|
||||
Name string `yaml:"name"`
|
||||
Action string `yaml:"action"`
|
||||
Log bool `yaml:"log"`
|
||||
Modifier ModifierEntry `yaml:"modifier"`
|
||||
Expr string `yaml:"expr"`
|
||||
StartTime string `yaml:"start_time"`
|
||||
StopTime string `yaml:"stop_time"`
|
||||
Weekdays []string `yaml:"weekdays"`
|
||||
}
|
||||
|
||||
type ModifierEntry struct {
|
||||
@@ -56,6 +59,10 @@ type compiledExprRule struct {
|
||||
ModInstance modifier.Instance
|
||||
Program *vm.Program
|
||||
GeoSiteConditions []string
|
||||
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||
Weekdays []time.Weekday
|
||||
WeekdaysNegated bool
|
||||
}
|
||||
|
||||
var _ Ruleset = (*exprRuleset)(nil)
|
||||
@@ -73,10 +80,13 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
||||
|
||||
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||
env := streamInfoToExprEnv(info)
|
||||
now := time.Now()
|
||||
for _, rule := range r.Rules {
|
||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||
continue
|
||||
}
|
||||
v, err := vm.Run(rule.Program, env)
|
||||
if err != nil {
|
||||
// Log the error and continue to the next rule.
|
||||
r.Logger.MatchError(info, rule.Name, err)
|
||||
continue
|
||||
}
|
||||
@@ -163,12 +173,34 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
depAnMap[name] = a
|
||||
}
|
||||
}
|
||||
startSecs := -1
|
||||
if rule.StartTime != "" {
|
||||
startSecs, err = parseTimeOfDay(rule.StartTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err)
|
||||
}
|
||||
}
|
||||
stopSecs := -1
|
||||
if rule.StopTime != "" {
|
||||
stopSecs, err = parseTimeOfDay(rule.StopTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err)
|
||||
}
|
||||
}
|
||||
weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
||||
}
|
||||
cr := compiledExprRule{
|
||||
Name: rule.Name,
|
||||
Action: action,
|
||||
Log: rule.Log,
|
||||
Program: program,
|
||||
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
||||
StartTimeSecs: startSecs,
|
||||
StopTimeSecs: stopSecs,
|
||||
Weekdays: weekdays,
|
||||
WeekdaysNegated: weekdaysNegated,
|
||||
}
|
||||
if action != nil && *action == ActionModify {
|
||||
mod, ok := fullModMap[rule.Modifier.Name]
|
||||
@@ -391,6 +423,87 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
|
||||
}, geoMatcher
|
||||
}
|
||||
|
||||
func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool {
|
||||
if startSecs >= 0 || stopSecs >= 0 {
|
||||
currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second()
|
||||
if startSecs >= 0 && stopSecs >= 0 {
|
||||
if startSecs <= stopSecs {
|
||||
if currentSecs < startSecs || currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if currentSecs < startSecs && currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if startSecs >= 0 {
|
||||
if currentSecs < startSecs {
|
||||
return false
|
||||
}
|
||||
} else if currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(weekdays) > 0 {
|
||||
current := now.Weekday()
|
||||
found := false
|
||||
for _, d := range weekdays {
|
||||
if current == d {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if negated == found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func parseTimeOfDay(s string) (int, error) {
|
||||
t, err := time.Parse("15:04:05", s)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s)
|
||||
}
|
||||
return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil
|
||||
}
|
||||
|
||||
func parseWeekdays(days []string) ([]time.Weekday, bool, error) {
|
||||
if len(days) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
negated := false
|
||||
parsed := make([]time.Weekday, 0, len(days))
|
||||
for i, d := range days {
|
||||
d = strings.TrimSpace(d)
|
||||
if i == 0 && strings.HasPrefix(d, "!") {
|
||||
negated = true
|
||||
d = strings.TrimSpace(strings.TrimPrefix(d, "!"))
|
||||
}
|
||||
var wd time.Weekday
|
||||
switch strings.ToLower(d) {
|
||||
case "sun", "sunday":
|
||||
wd = time.Sunday
|
||||
case "mon", "monday":
|
||||
wd = time.Monday
|
||||
case "tue", "tues", "tuesday":
|
||||
wd = time.Tuesday
|
||||
case "wed", "wednesday":
|
||||
wd = time.Wednesday
|
||||
case "thu", "thur", "thurs", "thursday":
|
||||
wd = time.Thursday
|
||||
case "fri", "friday":
|
||||
wd = time.Friday
|
||||
case "sat", "saturday":
|
||||
wd = time.Saturday
|
||||
default:
|
||||
return nil, false, fmt.Errorf("invalid weekday %q", d)
|
||||
}
|
||||
parsed = append(parsed, wd)
|
||||
}
|
||||
return parsed, negated, nil
|
||||
}
|
||||
|
||||
const rulesetLogMetaKey = "_ruleset"
|
||||
|
||||
func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo {
|
||||
|
||||
121
ruleset/interface_test.go
Normal file
121
ruleset/interface_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
func TestAction_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
action Action
|
||||
want string
|
||||
}{
|
||||
{ActionMaybe, "maybe"},
|
||||
{ActionAllow, "allow"},
|
||||
{ActionBlock, "block"},
|
||||
{ActionDrop, "drop"},
|
||||
{ActionModify, "modify"},
|
||||
{Action(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
if got := tt.action.String(); got != tt.want {
|
||||
t.Errorf("Action.String() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocol_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
protocol Protocol
|
||||
want string
|
||||
}{
|
||||
{ProtocolTCP, "tcp"},
|
||||
{ProtocolUDP, "udp"},
|
||||
{Protocol(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
if got := tt.protocol.String(); got != tt.want {
|
||||
t.Errorf("Protocol.String() = %q, want %q", got, tt.protocol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocol_Constants(t *testing.T) {
|
||||
if ProtocolTCP != 0 {
|
||||
t.Errorf("ProtocolTCP = %d, want 0", ProtocolTCP)
|
||||
}
|
||||
if ProtocolUDP != 1 {
|
||||
t.Errorf("ProtocolUDP = %d, want 1", ProtocolUDP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAction_Constants(t *testing.T) {
|
||||
if ActionMaybe != 0 {
|
||||
t.Errorf("ActionMaybe = %d, want 0", ActionMaybe)
|
||||
}
|
||||
if ActionAllow != 1 {
|
||||
t.Errorf("ActionAllow = %d, want 1", ActionAllow)
|
||||
}
|
||||
if ActionBlock != 2 {
|
||||
t.Errorf("ActionBlock = %d, want 2", ActionBlock)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_SrcString(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
SrcIP: net.ParseIP("192.168.1.1"),
|
||||
SrcPort: 8080,
|
||||
}
|
||||
want := "192.168.1.1:8080"
|
||||
if got := info.SrcString(); got != want {
|
||||
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_DstString(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
DstIP: net.ParseIP("10.0.0.1"),
|
||||
DstPort: 443,
|
||||
}
|
||||
want := "10.0.0.1:443"
|
||||
if got := info.DstString(); got != want {
|
||||
t.Errorf("DstString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_SrcString_IPv6(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
SrcIP: net.ParseIP("::1"),
|
||||
SrcPort: 53,
|
||||
}
|
||||
want := "[::1]:53"
|
||||
if got := info.SrcString(); got != want {
|
||||
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchResult_ZeroValue(t *testing.T) {
|
||||
var mr MatchResult
|
||||
if mr.Action != ActionMaybe {
|
||||
t.Errorf("zero MatchResult.Action = %v, want ActionMaybe (0)", mr.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_PropsInitialization(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
Props: analyzer.CombinedPropMap{
|
||||
"tls": analyzer.PropMap{"sni": "example.com"},
|
||||
},
|
||||
}
|
||||
if info.Props.Get("tls", "sni") != "example.com" {
|
||||
t.Error("StreamInfo.Props not properly initialized")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user