diff --git a/README.md b/README.md index 1d63ace..1d13768 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,29 @@ func main() { } ``` +## Testing + +Run all tests: + +```bash +go test ./... +``` + +Run tests with coverage: + +```bash +go test -cover ./... # per-package summary +go test -coverprofile=coverage.out ./... # full profile +go tool cover -html=coverage.out -o coverage.html # HTML report +``` + +Run tests for a specific package: + +```bash +go test -v ./analyzer/tcp/ +go test -v -run TestSSH ./analyzer/tcp/ +``` + Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW ## LICENSE diff --git a/analyzer/interface_test.go b/analyzer/interface_test.go new file mode 100644 index 0000000..e77a96d --- /dev/null +++ b/analyzer/interface_test.go @@ -0,0 +1,124 @@ +package analyzer + +import ( + "reflect" + "testing" +) + +func TestPropMap_Get(t *testing.T) { + m := PropMap{ + "a": "value-a", + "nested": PropMap{ + "b": "value-b", + "deep": PropMap{ + "c": "value-c", + }, + }, + } + + tests := []struct { + name string + key string + want interface{} + }{ + {"simple key", "a", "value-a"}, + {"nested key", "nested.b", "value-b"}, + {"deeply nested", "nested.deep.c", "value-c"}, + {"non-existent top", "x", nil}, + {"non-existent nested", "nested.x", nil}, + {"empty key", "", nil}, + {"partial path", "nested", PropMap{"b": "value-b", "deep": PropMap{"c": "value-c"}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := m.Get(tt.key) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("PropMap.Get(%q) = %v, want %v", tt.key, got, tt.want) + } + }) + } +} + +func TestPropMap_Get_NilMap(t *testing.T) { + var m PropMap + if got := m.Get("any"); got != nil { + t.Errorf("nil PropMap.Get() = %v, want nil", got) + } +} + +func TestPropMap_Get_EmptyMap(t *testing.T) { + m := PropMap{} + if got := m.Get("any"); got != nil { + t.Errorf("empty PropMap.Get() = %v, want nil", got) + } +} + +func TestPropMap_Get_IntermediateNonMap(t *testing.T) { + m := PropMap{ + "a": "hello", + } + if got := m.Get("a.b.c"); got != nil { + t.Errorf("PropMap.Get() through non-map = %v, want nil", got) + } +} + +func TestCombinedPropMap_Get(t *testing.T) { + cm := CombinedPropMap{ + "tls": PropMap{ + "req": PropMap{ + "sni": "example.com", + }, + }, + "http": PropMap{ + "resp": PropMap{ + "status": 200, + }, + }, + } + + tests := []struct { + name string + analyzer string + key string + want interface{} + }{ + {"tls sni", "tls", "req.sni", "example.com"}, + {"http status", "http", "resp.status", 200}, + {"unknown analyzer", "dns", "any", nil}, + {"valid analyzer bad key", "tls", "req.nonexistent", nil}, + {"empty analyzer", "", "key", nil}, + {"empty key", "tls", "", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cm.Get(tt.analyzer, tt.key) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CombinedPropMap.Get(%q, %q) = %v, want %v", tt.analyzer, tt.key, got, tt.want) + } + }) + } +} + +func TestCombinedPropMap_Get_Nil(t *testing.T) { + var cm CombinedPropMap + if got := cm.Get("any", "key"); got != nil { + t.Errorf("nil CombinedPropMap.Get() = %v, want nil", got) + } +} + +func TestPropUpdateType_Values(t *testing.T) { + if PropUpdateNone != 0 { + t.Errorf("PropUpdateNone = %d, want 0", PropUpdateNone) + } + if PropUpdateMerge != 1 { + t.Errorf("PropUpdateMerge = %d, want 1", PropUpdateMerge) + } + if PropUpdateReplace != 2 { + t.Errorf("PropUpdateReplace = %d, want 2", PropUpdateReplace) + } + if PropUpdateDelete != 3 { + t.Errorf("PropUpdateDelete = %d, want 3", PropUpdateDelete) + } +} diff --git a/analyzer/internal/tls_test.go b/analyzer/internal/tls_test.go new file mode 100644 index 0000000..e8a21dd --- /dev/null +++ b/analyzer/internal/tls_test.go @@ -0,0 +1,319 @@ +package internal + +import ( + "reflect" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer/utils" +) + +func buildClientHelloMsg(t *testing.T) *utils.ByteBuffer { + t.Helper() + // Bytes taken from the standard TLS test vector, starting after the record + // header (5 bytes) and handshake header (4 bytes). + body := []byte{ + 0x03, 0x03, // version TLS 1.2 + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length = 0 + 0x00, 0x20, // cipher suites length = 32 bytes = 16 suites + 0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30, + 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, + 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d, + 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, // ciphers + 0x01, // compression methods length = 1 + 0x00, // compression method = null + 0x00, 0x58, // extensions length = 88 + // extension: server_name + 0x00, 0x00, // type = server_name + 0x00, 0x18, // length = 24 + 0x00, 0x16, // server name list length = 22 + 0x00, // name type = hostname + 0x00, 0x13, // name length = 19 + 'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'u', 'l', 'f', 'h', 'e', 'i', 'm', '.', 'n', 'e', 't', + // extension: status_request + 0x00, 0x05, // type = status_request + 0x00, 0x05, // length = 5 + 0x01, 0x00, 0x00, 0x00, 0x00, + // extension: supported_groups + 0x00, 0x0a, // type = supported_groups + 0x00, 0x0a, // length = 10 + 0x00, 0x08, // list length = 8 + 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, + // extension: ec_point_formats + 0x00, 0x0b, // type = ec_point_formats + 0x00, 0x02, // length = 2 + 0x01, 0x00, + // extension: signature_algorithms + 0x00, 0x0d, // type = signature_algorithms + 0x00, 0x12, // length = 18 + 0x00, 0x10, // list length = 16 + 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, + 0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, + // extension: renegotiation_info + 0xff, 0x01, // type = renegotiation_info + 0x00, 0x01, // length = 1 + 0x00, + // extension: extended_master_secret (empty) + 0x00, 0x17, // type = extended_master_secret + 0x00, 0x00, // length = 0 + // extension: session_ticket (empty) + 0x00, 0x23, // type = session_ticket + 0x00, 0x00, // length = 0 + } + return &utils.ByteBuffer{Buf: body} +} + +func TestParseTLSClientHelloMsgData(t *testing.T) { + chBuf := buildClientHelloMsg(t) + m := ParseTLSClientHelloMsgData(chBuf) + if m == nil { + t.Fatal("ParseTLSClientHelloMsgData returned nil") + } + + wantVersion := uint16(0x0303) + if v, ok := m["version"].(uint16); !ok || v != wantVersion { + t.Errorf("version = %v, want %v", m["version"], wantVersion) + } + + wantCiphers := []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10} + if c, ok := m["ciphers"].([]uint16); ok { + if !reflect.DeepEqual(c, wantCiphers) { + t.Errorf("ciphers = %v, want %v", c, wantCiphers) + } + } else { + t.Errorf("ciphers missing or wrong type: %T", m["ciphers"]) + } + + if sni, ok := m["sni"].(string); !ok || sni != "example.ulfheim.net" { + t.Errorf("sni = %q, want %q", m["sni"], "example.ulfheim.net") + } + + if _, ok := m["compression"]; !ok { + t.Error("compression key missing") + } + + if _, ok := m["session"]; !ok { + t.Error("session key missing") + } + + if _, ok := m["random"]; !ok { + t.Error("random key missing") + } +} + +func TestParseTLSServerHelloMsgData(t *testing.T) { + // ServerHello message body (after record header + handshake header) + body := []byte{ + 0x03, 0x03, // version TLS 1.2 + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, // random + 0x00, // session ID length = 0 + 0xc0, 0x13, // cipher suite + 0x00, // compression method + 0x00, 0x05, // extensions length = 5 + // extension: renegotiation_info + 0xff, 0x01, // type + 0x00, 0x01, // length = 1 + 0x00, + } + + m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSServerHelloMsgData returned nil") + } + + wantCipher := uint16(0xc013) + if c, ok := m["cipher"].(uint16); !ok || c != wantCipher { + t.Errorf("cipher = %v, want %v", m["cipher"], wantCipher) + } + + wantCompression := uint8(0) + if c, ok := m["compression"].(uint8); !ok || c != wantCompression { + t.Errorf("compression = %v, want %v", m["compression"], wantCompression) + } +} + +func TestParseTLSServerHelloMsgData_NoExtensions(t *testing.T) { + // ServerHello message without extensions + body := []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x00, 0xff, // cipher suite + 0x00, // compression method + } + + m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSServerHelloMsgData returned nil for no-extensions case") + } + if _, ok := m["cipher"]; !ok { + t.Error("cipher key missing") + } +} + +func TestParseTLSClientHelloMsgData_Truncated(t *testing.T) { + tests := []struct { + name string + buf []byte + }{ + {"too short for session id", []byte{0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}}, + {"odd cipher suites length", []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x00, 0x03, // odd cipher suites length + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: tt.buf}) + if m != nil { + t.Error("expected nil for truncated input") + } + }) + } +} + +func TestParseTLSClientHelloMsgData_ECH(t *testing.T) { + // ClientHello with ECH extension + body := []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x00, 0x02, // cipher suites length = 2 + 0x13, 0x01, // TLS_AES_128_GCM_SHA256 + 0x01, // compression methods length = 1 + 0x00, // null compression + 0x00, 0x07, // extensions length = 7 + // ECH extension + 0xfe, 0x0d, // type = encrypted_client_hello + 0x00, 0x03, // length = 3 + 0x01, 0x02, 0x03, // some ECH data + } + + m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSClientHelloMsgData returned nil") + } + ech, ok := m["ech"].(bool) + if !ok || !ech { + t.Errorf("ech = %v, want true", m["ech"]) + } +} + +func TestParseTLSClientHelloMsgData_ALPN(t *testing.T) { + // ClientHello with ALPN extension + body := []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x00, 0x02, // cipher suites length = 2 + 0x13, 0x01, // cipher suite + 0x01, // compression methods length = 1 + 0x00, // compression method + 0x00, 0x12, // extensions length = 18 + // ALPN extension + 0x00, 0x10, // type = ALPN + 0x00, 0x0e, // length = 14 + 0x00, 0x0c, // list length = 12 + 0x08, 'h', 't', 't', 'p', '/', '1', '.', '1', + 0x02, 'h', '2', + } + + m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSClientHelloMsgData returned nil") + } + alpn, ok := m["alpn"].([]string) + if !ok { + t.Fatalf("alpn missing or wrong type: %T", m["alpn"]) + } + if len(alpn) != 2 || alpn[0] != "http/1.1" || alpn[1] != "h2" { + t.Errorf("alpn = %v, want [http/1.1 h2]", alpn) + } +} + +func TestParseTLSClientHelloMsgData_SupportedVersionsClient(t *testing.T) { + // ClientHello with supported_versions extension (client format - list) + body := []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x00, 0x02, // cipher suites length = 2 + 0x13, 0x01, // cipher suite + 0x01, // compression methods length = 1 + 0x00, // compression method + 0x00, 0x0b, // extensions length = 11 + // supported_versions (client list format) + 0x00, 0x2b, // type = supported_versions + 0x00, 0x07, // length = 7 + 0x06, // list length = 6 + 0x03, 0x04, // TLS 1.3 + 0x03, 0x03, // TLS 1.2 + 0x03, 0x01, // TLS 1.0 + } + + m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSClientHelloMsgData returned nil") + } + versions, ok := m["supported_versions"].([]uint16) + if !ok { + t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"]) + } + want := []uint16{0x0304, 0x0303, 0x0301} + if !reflect.DeepEqual(versions, want) { + t.Errorf("supported_versions = %v, want %v", versions, want) + } +} + +func TestParseTLSServerHelloMsgData_SupportedVersionsServer(t *testing.T) { + // ServerHello with supported_versions extension (server format - single value) + body := []byte{ + 0x03, 0x03, // version + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random + 0x00, // session ID length + 0x13, 0x01, // cipher suite + 0x00, // compression method + 0x00, 0x06, // extensions length = 6 + // supported_versions (server format - single value) + 0x00, 0x2b, // type + 0x00, 0x02, // length = 2 + 0x03, 0x04, // TLS 1.3 + } + + m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body}) + if m == nil { + t.Fatal("ParseTLSServerHelloMsgData returned nil") + } + v, ok := m["supported_versions"].(uint16) + if !ok { + t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"]) + } + if v != 0x0304 { + t.Errorf("supported_versions = 0x%04x, want 0x0304", v) + } +} diff --git a/analyzer/tcp/fet_test.go b/analyzer/tcp/fet_test.go new file mode 100644 index 0000000..aded7b5 --- /dev/null +++ b/analyzer/tcp/fet_test.go @@ -0,0 +1,256 @@ +package tcp + +import ( + "testing" +) + +func TestPopCount(t *testing.T) { + tests := []struct { + input byte + want int + }{ + {0x00, 0}, + {0x01, 1}, + {0x02, 1}, + {0x03, 2}, + {0x07, 3}, + {0x0f, 4}, + {0xff, 8}, + {0x55, 4}, // 01010101 + {0xaa, 4}, // 10101010 + } + + for _, tt := range tests { + if got := popCount(tt.input); got != tt.want { + t.Errorf("popCount(0x%02x) = %d, want %d", tt.input, got, tt.want) + } + } +} + +func TestAveragePopCount(t *testing.T) { + tests := []struct { + name string + input []byte + want float32 + }{ + {"empty", []byte{}, 0}, + {"all zeros", []byte{0x00, 0x00, 0x00}, 0}, + {"all ones", []byte{0xff, 0xff}, 8.0}, + {"mixed", []byte{0x00, 0xff}, 4.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := averagePopCount(tt.input) + if got != tt.want { + t.Errorf("averagePopCount() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsFirstSixPrintable(t *testing.T) { + tests := []struct { + name string + input []byte + want bool + }{ + {"too short", []byte("abc"), false}, + {"all printable", []byte("abcdef"), true}, + {"non-printable at pos 0", []byte{0x00, 'a', 'b', 'c', 'd', 'e'}, false}, + {"non-printable at pos 5", []byte{'a', 'b', 'c', 'd', 'e', 0x1f}, false}, + {"exactly 6 printable", []byte("123456"), true}, + {"spaces", []byte(" "), true}, + {"non-printable middle", []byte{'a', 'b', 0x01, 'd', 'e', 'f'}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isFirstSixPrintable(tt.input) + if got != tt.want { + t.Errorf("isFirstSixPrintable(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestPrintablePercentage(t *testing.T) { + tests := []struct { + name string + input []byte + want float32 + }{ + {"empty", []byte{}, 0}, + {"all printable", []byte("hello"), 1.0}, + {"none printable", []byte{0x00, 0x01, 0x02, 0x03}, 0}, + {"half printable", []byte{'a', 0x00, 'b', 0x00}, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := printablePercentage(tt.input) + if got != tt.want { + t.Errorf("printablePercentage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestContiguousPrintable(t *testing.T) { + tests := []struct { + name string + input []byte + want int + }{ + {"empty", []byte{}, 0}, + {"all printable", []byte("hello world"), 11}, + {"none printable", []byte{0x00, 0x01, 0x02}, 0}, + {"start printable", []byte{'a', 'b', 'c', 0x00, 'd', 'e', 'f'}, 3}, + {"end printable", []byte{0x00, 'a', 'b', 'c', 'd', 'e', 'f'}, 6}, + {"middle printable", []byte{0x00, 'a', 'b', 'c', 0x00}, 3}, + {"two segments", []byte{'a', 'b', 0x00, 'c', 'd', 'e'}, 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := contiguousPrintable(tt.input) + if got != tt.want { + t.Errorf("contiguousPrintable() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestIsTLSorHTTP(t *testing.T) { + tests := []struct { + name string + input []byte + want bool + }{ + {"too short", []byte("AB"), false}, + // TLS ClientHello: 0x16 0x03 0x01 + {"tls 1.0", []byte{0x16, 0x03, 0x01}, true}, + // TLS 0x17 is application data record + {"tls app data", []byte{0x17, 0x03, 0x03}, true}, + {"tls max content type", []byte{0x17, 0x03, 0x09}, true}, + {"bad tls content type", []byte{0x15, 0x03, 0x01}, false}, + {"bad tls version", []byte{0x16, 0x04, 0x01}, false}, + {"bad tls length", []byte{0x16, 0x03, 0x0a}, false}, + // HTTP methods + {"GET", []byte("GET / HTTP/1.1..."), true}, + {"HEAD", []byte("HEAD /index.html..."), true}, + {"POST", []byte("POST /api..."), true}, + {"PUT", []byte("PUT /data..."), true}, + {"DELETE", []byte("DELETE /..."), true}, + {"CONNECT", []byte("CONNECT proxy..."), true}, + {"OPTIONS", []byte("OPTIONS *..."), true}, + {"TRACE", []byte("TRACE /..."), true}, + {"PATCH", []byte("PATCH /..."), true}, + {"random data", []byte{0x00, 0x01, 0x02, 0x03}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isTLSorHTTP(tt.input) + if got != tt.want { + t.Errorf("isTLSorHTTP(%v) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestIsPrintable(t *testing.T) { + if !isPrintable('a') { + t.Error("'a' should be printable") + } + if isPrintable(0x00) { + t.Error("0x00 should not be printable") + } + if !isPrintable(0x20) { + t.Error("0x20 (space) should be printable") + } + if !isPrintable(0x7e) { + t.Error("0x7e (~) should be printable") + } + if isPrintable(0x7f) { + t.Error("0x7f (DEL) should not be printable") + } +} + +func TestFETStream_Feed(t *testing.T) { + s := newFETStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) + if u == nil { + t.Fatal("Feed returned nil update") + } + if !done { + t.Error("FET should be done after first packet") + } + m := u.M + if _, ok := m["ex1"]; !ok { + t.Error("ex1 missing") + } + if _, ok := m["ex2"]; !ok { + t.Error("ex2 missing") + } + if _, ok := m["ex3"]; !ok { + t.Error("ex3 missing") + } + if _, ok := m["ex4"]; !ok { + t.Error("ex4 missing") + } + if _, ok := m["ex5"]; !ok { + t.Error("ex5 missing") + } + if yes, ok := m["yes"].(bool); ok && yes { + t.Error("HTTP should be exempt (yes=false)") + } + if u.Type != 2 { + t.Errorf("prop update type = %d, want PropUpdateReplace", u.Type) + } +} + +func TestFETStream_Feed_EncryptedLike(t *testing.T) { + s := newFETStream(nil) + data := make([]byte, 100) + for i := range data { + data[i] = byte(i % 256) + } + u, done := s.Feed(false, false, false, 0, data) + if u == nil { + t.Fatal("Feed returned nil update") + } + if !done { + t.Error("should be done") + } +} + +func TestFETStream_Feed_Skip(t *testing.T) { + s := newFETStream(nil) + _, done := s.Feed(false, false, false, 5, []byte("data")) + if !done { + t.Error("skip != 0 should return done=true") + } +} + +func TestFETStream_Feed_Empty(t *testing.T) { + s := newFETStream(nil) + u, done := s.Feed(false, false, false, 0, []byte{}) + if u != nil || done { + t.Error("empty data should return nil, false") + } +} + +func TestFETAnalyzer_Name(t *testing.T) { + a := &FETAnalyzer{} + if a.Name() != "fet" { + t.Errorf("Name() = %q, want fet", a.Name()) + } +} + +func TestFETStream_Close(t *testing.T) { + s := newFETStream(nil) + if u := s.Close(false); u != nil { + t.Error("Close should return nil") + } +} diff --git a/analyzer/tcp/ssh_test.go b/analyzer/tcp/ssh_test.go new file mode 100644 index 0000000..83b9914 --- /dev/null +++ b/analyzer/tcp/ssh_test.go @@ -0,0 +1,159 @@ +package tcp + +import ( + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +func TestSSHAnalyzer_Name(t *testing.T) { + a := &SSHAnalyzer{} + if a.Name() != "ssh" { + t.Errorf("Name() = %q, want ssh", a.Name()) + } +} + +func TestSSHStream_Feed_Client(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3\r\n")) + if u == nil { + t.Fatal("Feed returned nil update") + } + if done { + t.Error("should not be done (server not yet received)") + } + client, ok := u.M["client"].(analyzer.PropMap) + if !ok { + t.Fatal("client prop missing") + } + if client["protocol"] != "2.0" { + t.Errorf("protocol = %v, want 2.0", client["protocol"]) + } + if client["software"] != "OpenSSH_8.9p1" { + t.Errorf("software = %v, want OpenSSH_8.9p1", client["software"]) + } + if comments, ok := client["comments"]; ok { + if comments != "Ubuntu-3" { + t.Errorf("comments = %v, want Ubuntu-3", comments) + } + } +} + +func TestSSHStream_Feed_ClientWithComments(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_7.4 Ubuntu-3\r\n")) + if u == nil { + t.Fatal("Feed returned nil update") + } + if done { + t.Error("should not be done") + } + client := u.M["client"].(analyzer.PropMap) + if client["comments"] != "Ubuntu-3" { + t.Errorf("comments = %v, want Ubuntu-3", client["comments"]) + } +} + +func TestSSHStream_Feed_Both(t *testing.T) { + s := newSSHStream(nil) + s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9\r\n")) + u, done := s.Feed(true, false, false, 0, []byte("SSH-2.0-dropbear_2022.83\r\n")) + if u == nil { + t.Fatal("Feed returned nil update") + } + if !done { + t.Error("should be done after both sides") + } + server, ok := u.M["server"].(analyzer.PropMap) + if !ok { + t.Fatal("server prop missing") + } + if server["software"] != "dropbear_2022.83" { + t.Errorf("server software = %v", server["software"]) + } +} + +func TestSSHStream_Feed_NotSSH(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("HTTP/1.1 200 OK\r\n")) + if u != nil { + t.Error("should return nil for non-SSH") + } + if !done { + t.Error("should be cancelled (done) for non-SSH") + } +} + +func TestSSHStream_Feed_InvalidLine(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-foo bar baz\r\n")) + if u != nil { + t.Error("should return nil for invalid line (>2 fields)") + } + if !done { + t.Error("should be cancelled for invalid SSH line") + } +} + +func TestSSHStream_Feed_NoLineEnd(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH")) + if u != nil { + t.Error("should return nil when no EOL found yet") + } + if done { + t.Error("should not be done, waiting for more data") + } +} + +func TestSSHStream_Feed_IncompleteThenComplete(t *testing.T) { + s := newSSHStream(nil) + u, _ := s.Feed(false, false, false, 0, []byte("SSH-2.0-")) + if u != nil { + t.Error("first partial feed should return nil") + } + u, done := s.Feed(false, false, false, 0, []byte("Dropbear\r\n")) + if u == nil { + t.Fatal("second feed should return update") + } + if done { + t.Error("should not be done (only client)") + } + client := u.M["client"].(analyzer.PropMap) + if client["software"] != "Dropbear" { + t.Errorf("software = %v", client["software"]) + } +} + +func TestSSHStream_Feed_Skip(t *testing.T) { + s := newSSHStream(nil) + _, done := s.Feed(false, false, false, 5, []byte("data")) + if !done { + t.Error("skip != 0 should return done=true") + } +} + +func TestSSHStream_Feed_Empty(t *testing.T) { + s := newSSHStream(nil) + u, done := s.Feed(false, false, false, 0, []byte{}) + if u != nil || done { + t.Error("empty data should return nil, false") + } +} + +func TestSSHStream_Close(t *testing.T) { + s := newSSHStream(nil) + s.clientBuf.Append([]byte("data")) + s.serverBuf.Append([]byte("data")) + s.clientMap = analyzer.PropMap{"key": "val"} + u := s.Close(false) + if u != nil { + t.Error("Close should return nil") + } + if s.clientBuf.Len() != 0 || s.serverBuf.Len() != 0 { + t.Error("Close should reset buffers") + } + if s.clientMap != nil || s.serverMap != nil { + t.Error("Close should nil maps") + } +} diff --git a/analyzer/udp/wireguard_test.go b/analyzer/udp/wireguard_test.go new file mode 100644 index 0000000..50d1d1a --- /dev/null +++ b/analyzer/udp/wireguard_test.go @@ -0,0 +1,187 @@ +package udp + +import ( + "encoding/binary" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +func makeWireGuardPacket(msgType byte, body []byte) []byte { + buf := make([]byte, 1+3+len(body)) + buf[0] = msgType + copy(buf[4:], body) + return buf +} + +func TestWireGuardUDPStream_Feed_HandshakeInitiation(t *testing.T) { + s := newWireGuardUDPStream(nil) + body := make([]byte, wireguardSizeHandshakeInitiation-4) + binary.LittleEndian.PutUint32(body[0:4], 0x01020304) + + u, done := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body)) + if u == nil { + t.Fatal("Feed returned nil update") + } + if done { + t.Error("should not be done") + } + if u.Type != analyzer.PropUpdateReplace { + t.Errorf("Type = %d, want PropUpdateReplace", u.Type) + } + initMap, ok := u.M["handshake_initiation"].(analyzer.PropMap) + if !ok { + t.Fatal("handshake_initiation missing") + } + if idx := initMap["sender_index"].(uint32); idx != 0x01020304 { + t.Errorf("sender_index = %d, want %d", idx, 0x01020304) + } +} + +func TestWireGuardUDPStream_Feed_HandshakeResponse(t *testing.T) { + s := newWireGuardUDPStream(nil) + body := make([]byte, wireguardSizeHandshakeResponse-4) + binary.LittleEndian.PutUint32(body[0:4], 0x01020304) + binary.LittleEndian.PutUint32(body[4:8], 0x05060708) + + u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeResponse, body)) + if u == nil { + t.Fatal("Feed returned nil update") + } + respMap, ok := u.M["handshake_response"].(analyzer.PropMap) + if !ok { + t.Fatal("handshake_response missing") + } + if idx := respMap["sender_index"].(uint32); idx != 0x01020304 { + t.Errorf("sender_index = %d", idx) + } + if idx := respMap["receiver_index"].(uint32); idx != 0x05060708 { + t.Errorf("receiver_index = %d", idx) + } +} + +func TestWireGuardUDPStream_Feed_PacketData(t *testing.T) { + s := newWireGuardUDPStream(nil) + body := make([]byte, wireguardMinSizePacketData-4) + binary.LittleEndian.PutUint32(body[0:4], 0x0a0b0c0d) + binary.LittleEndian.PutUint64(body[4:12], 42) + + u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeData, body)) + if u == nil { + t.Fatal("Feed returned nil update") + } + dataMap, ok := u.M["packet_data"].(analyzer.PropMap) + if !ok { + t.Fatal("packet_data missing") + } + if idx := dataMap["receiver_index"].(uint32); idx != 0x0a0b0c0d { + t.Errorf("receiver_index = %d", idx) + } + if ctr := dataMap["counter"].(uint64); ctr != 42 { + t.Errorf("counter = %d", ctr) + } +} + +func TestWireGuardUDPStream_Feed_CookieReply(t *testing.T) { + s := newWireGuardUDPStream(nil) + body := make([]byte, wireguardSizePacketCookieReply-4) + binary.LittleEndian.PutUint32(body[0:4], 0x11111111) + + u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeCookieReply, body)) + if u == nil { + t.Fatal("Feed returned nil update") + } + crMap, ok := u.M["packet_cookie_reply"].(analyzer.PropMap) + if !ok { + t.Fatal("packet_cookie_reply missing") + } + if idx := crMap["receiver_index"].(uint32); idx != 0x11111111 { + t.Errorf("receiver_index = %d", idx) + } +} + +func TestWireGuardUDPStream_Feed_InvalidLength(t *testing.T) { + s := newWireGuardUDPStream(nil) + u, done := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01}) + if u == nil || !done { + u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01}) + } + if u == nil || !done { + u, _ = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01}) + } + u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01}) + if u != nil || !done { + t.Error("should return done after 4 invalid packets") + } +} + +func TestWireGuardUDPStream_Feed_TooShort(t *testing.T) { + s := newWireGuardUDPStream(nil) + u, _ := s.Feed(false, []byte{1, 0, 0}) + if u != nil { + t.Error("should return nil for too-short packet") + } +} + +func TestWireGuardUDPStream_Feed_NonZeroReserved(t *testing.T) { + s := newWireGuardUDPStream(nil) + u, _ := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 1, 0, 0, 0, 0}) + if u != nil { + t.Error("should return nil when reserved bytes are non-zero") + } +} + +func TestWireGuardUDPStream_Feed_HandshakeInitiation_WrongSize(t *testing.T) { + s := newWireGuardUDPStream(nil) + body := make([]byte, wireguardSizeHandshakeInitiation-4+1) + binary.LittleEndian.PutUint32(body[0:4], 1) + u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body)) + if u != nil { + t.Error("should return nil for wrong init size") + } +} + +func TestWireGuardUDPStream_Close(t *testing.T) { + s := newWireGuardUDPStream(nil) + u := s.Close(false) + if u != nil { + t.Error("Close() should return nil") + } +} + +func TestWireGuardUDPStream_PutSenderIndex_MatchReceiverIndex(t *testing.T) { + s := newWireGuardUDPStream(nil) + + s.putSenderIndex(false, 0xdeadbeef) + if !s.matchReceiverIndex(true, 0xdeadbeef) { + t.Error("should match reverse direction") + } + if s.matchReceiverIndex(false, 0xdeadbeef) { + t.Error("should not match same direction") + } + if s.matchReceiverIndex(true, 0x12345678) { + t.Error("should not match wrong index") + } +} + +func TestWireGuardUDPStream_RememberedIndexRing(t *testing.T) { + s := newWireGuardUDPStream(nil) + + for i := uint32(0); i < wireguardRememberedIndexCount+2; i++ { + s.putSenderIndex(false, i) + } + + if s.matchReceiverIndex(true, 0) { + t.Error("index 0 should have been evicted from ring") + } + if !s.matchReceiverIndex(true, uint32(wireguardRememberedIndexCount+1)) { + t.Error("latest index should still be in ring") + } +} + +func TestWireGuardAnalyzer_Name(t *testing.T) { + a := &WireGuardAnalyzer{} + if a.Name() != "wireguard" { + t.Errorf("Name() = %q, want wireguard", a.Name()) + } +} diff --git a/analyzer/utils/bytebuffer_test.go b/analyzer/utils/bytebuffer_test.go new file mode 100644 index 0000000..99eb559 --- /dev/null +++ b/analyzer/utils/bytebuffer_test.go @@ -0,0 +1,321 @@ +package utils + +import ( + "bytes" + "reflect" + "testing" +) + +func TestByteBuffer_Append(t *testing.T) { + b := &ByteBuffer{} + b.Append([]byte("hello")) + b.Append([]byte(" ")) + b.Append([]byte("world")) + if string(b.Buf) != "hello world" { + t.Errorf("Append() result = %q, want %q", string(b.Buf), "hello world") + } +} + +func TestByteBuffer_Len(t *testing.T) { + b := &ByteBuffer{} + if b.Len() != 0 { + t.Errorf("Len() = %d, want 0", b.Len()) + } + b.Append([]byte("abc")) + if b.Len() != 3 { + t.Errorf("Len() = %d, want 3", b.Len()) + } +} + +func TestByteBuffer_Index(t *testing.T) { + b := &ByteBuffer{Buf: []byte("hello world")} + if i := b.Index([]byte("world")); i != 6 { + t.Errorf("Index('world') = %d, want 6", i) + } + if i := b.Index([]byte("xyz")); i != -1 { + t.Errorf("Index('xyz') = %d, want -1", i) + } + if i := b.Index([]byte{}); i != 0 { + t.Errorf("Index('') = %d, want 0", i) + } +} + +func TestByteBuffer_Get(t *testing.T) { + b := &ByteBuffer{Buf: []byte("abcdef")} + + data, ok := b.Get(3, true) + if !ok { + t.Fatal("Get(3, true) returned false") + } + if !bytes.Equal(data, []byte("abc")) { + t.Errorf("Get(3, true) data = %q, want %q", data, "abc") + } + if !bytes.Equal(b.Buf, []byte("def")) { + t.Errorf("after consume, Buf = %q, want %q", b.Buf, "def") + } + + data, ok = b.Get(4, false) + if ok { + t.Fatal("Get(4, false) should return false (only 3 bytes left)") + } + + data, ok = b.Get(2, false) + if !ok { + t.Fatal("Get(2, false) returned false") + } + if !bytes.Equal(data, []byte("de")) { + t.Errorf("Get(2, false) data = %q, want %q", data, "de") + } + if !bytes.Equal(b.Buf, []byte("def")) { + t.Errorf("after non-consume, Buf = %q, want %q (unchanged)", b.Buf, "def") + } + + data, ok = b.Get(3, true) + if !ok { + t.Fatal("Get(3, true) returned false") + } + if !bytes.Equal(data, []byte("def")) { + t.Errorf("Get(3, true) data = %q, want %q", data, "def") + } + if b.Len() != 0 { + t.Errorf("after consume all, Len() = %d, want 0", b.Len()) + } +} + +func TestByteBuffer_GetString(t *testing.T) { + b := &ByteBuffer{Buf: []byte("hello world")} + + s, ok := b.GetString(5, true) + if !ok { + t.Fatal("GetString(5, true) returned false") + } + if s != "hello" { + t.Errorf("GetString() = %q, want %q", s, "hello") + } + if b.Len() != 6 { + t.Errorf("after consume, Len() = %d, want 6", b.Len()) + } + + s, ok = b.GetString(7, false) + if ok { + t.Fatal("GetString(7, false) should return false") + } + + s, ok = b.GetString(6, false) + if !ok { + t.Fatal("GetString(6, false) returned false") + } + if s != " world" { + t.Errorf("GetString() = %q, want %q", s, " world") + } +} + +func TestByteBuffer_GetByte(t *testing.T) { + b := &ByteBuffer{Buf: []byte("abc")} + + bt, ok := b.GetByte(true) + if !ok { + t.Fatal("GetByte(true) returned false") + } + if bt != 'a' { + t.Errorf("GetByte() = %c, want 'a'", bt) + } + if b.Len() != 2 { + t.Errorf("after consume, Len() = %d, want 2", b.Len()) + } + + bt, ok = b.GetByte(false) + if !ok { + t.Fatal("GetByte(false) returned false") + } + if bt != 'b' { + t.Errorf("GetByte() = %c, want 'b'", bt) + } + if b.Len() != 2 { + t.Errorf("after non-consume, Len() = %d, want 2 (unchanged)", b.Len()) + } + + b.GetByte(true) + b.GetByte(true) + bt, ok = b.GetByte(true) + if ok { + t.Fatal("GetByte(true) on empty buffer should return false") + } +} + +func TestByteBuffer_GetUint16(t *testing.T) { + b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04}} + + v, ok := b.GetUint16(false, true) + if !ok { + t.Fatal("GetUint16(bigEndian) returned false") + } + if v != 0x0102 { + t.Errorf("GetUint16(bigEndian) = 0x%04x, want 0x0102", v) + } + if b.Len() != 2 { + t.Errorf("after consume, Len() = %d, want 2", b.Len()) + } + + v, ok = b.GetUint16(true, true) + if !ok { + t.Fatal("GetUint16(littleEndian) returned false") + } + if v != 0x0403 { + t.Errorf("GetUint16(littleEndian) = 0x%04x, want 0x0403", v) + } + + v, ok = b.GetUint16(false, false) + if ok { + t.Fatal("GetUint16 on empty buffer should return false") + } +} + +func TestByteBuffer_GetUint32(t *testing.T) { + b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}} + + v, ok := b.GetUint32(false, true) + if !ok { + t.Fatal("GetUint32(bigEndian) returned false") + } + if v != 0x01020304 { + t.Errorf("GetUint32(bigEndian) = 0x%08x, want 0x01020304", v) + } + if b.Len() != 4 { + t.Errorf("after consume, Len() = %d, want 4", b.Len()) + } + + v, ok = b.GetUint32(true, true) + if !ok { + t.Fatal("GetUint32(littleEndian) returned false") + } + if v != 0x08070605 { + t.Errorf("GetUint32(littleEndian) = 0x%08x, want 0x08070605", v) + } + + v, ok = b.GetUint32(false, false) + if ok { + t.Fatal("GetUint32 on empty buffer should return false") + } +} + +func TestByteBuffer_GetUntil(t *testing.T) { + b := &ByteBuffer{Buf: []byte("hello\r\nworld\r\n")} + + data, ok := b.GetUntil([]byte("\r\n"), true, true) + if !ok { + t.Fatal("GetUntil(sep, include) returned false") + } + if !bytes.Equal(data, []byte("hello\r\n")) { + t.Errorf("GetUntil(include) = %q, want %q", data, "hello\\r\\n") + } + + data, ok = b.GetUntil([]byte("\r\n"), false, false) + if !ok { + t.Fatal("GetUntil(sep, exclude, non-consume) returned false") + } + if !bytes.Equal(data, []byte("world")) { + t.Errorf("GetUntil(exclude) = %q, want %q", data, "world") + } + if b.Len() != 7 { + t.Errorf("after non-consume, Len() = %d, want 7", b.Len()) + } + + data, ok = b.GetUntil([]byte("\r\n"), true, true) + if !ok { + t.Fatal("GetUntil second (consume) returned false") + } + if !bytes.Equal(data, []byte("world\r\n")) { + t.Errorf("GetUntil second = %q, want %q", data, "world\\r\\n") + } + + _, ok = b.GetUntil([]byte("xyz"), false, false) + if ok { + t.Fatal("GetUntil(not found) should return false") + } +} + +func TestByteBuffer_GetSubBuffer(t *testing.T) { + b := &ByteBuffer{Buf: []byte("hello world")} + + sub, ok := b.GetSubBuffer(5, true) + if !ok { + t.Fatal("GetSubBuffer() returned false") + } + if !bytes.Equal(sub.Buf, []byte("hello")) { + t.Errorf("GetSubBuffer() = %q, want %q", sub.Buf, "hello") + } + if b.Len() != 6 { + t.Errorf("after consume, Len() = %d, want 6", b.Len()) + } + + _, ok = b.GetSubBuffer(7, false) + if ok { + t.Fatal("GetSubBuffer(7) should return false (only 6 bytes left)") + } +} + +func TestByteBuffer_Skip(t *testing.T) { + b := &ByteBuffer{Buf: []byte("abcdef")} + + ok := b.Skip(2) + if !ok { + t.Fatal("Skip(2) returned false") + } + if !bytes.Equal(b.Buf, []byte("cdef")) { + t.Errorf("after Skip(2), Buf = %q, want %q", b.Buf, "cdef") + } + + ok = b.Skip(10) + if ok { + t.Fatal("Skip(10) should return false") + } + if !bytes.Equal(b.Buf, []byte("cdef")) { + t.Errorf("after failed Skip, Buf = %q, want %q (unchanged)", b.Buf, "cdef") + } + + ok = b.Skip(4) + if !ok { + t.Fatal("Skip(4) returned false") + } + if b.Len() != 0 { + t.Errorf("after Skip all, Len() = %d, want 0", b.Len()) + } +} + +func TestByteBuffer_Reset(t *testing.T) { + b := &ByteBuffer{Buf: []byte("data")} + b.Reset() + if b.Buf != nil { + t.Errorf("after Reset, Buf = %v, want nil", b.Buf) + } +} + +func TestByteBuffer_GetZeroLength(t *testing.T) { + b := &ByteBuffer{Buf: []byte("abc")} + + data, ok := b.Get(0, true) + if !ok { + t.Fatal("Get(0) returned false") + } + if len(data) != 0 { + t.Errorf("Get(0) len = %d, want 0", len(data)) + } + if b.Len() != 3 { + t.Errorf("after Get(0, consume), Len() = %d, want 3 (0-length consume is no-op)", b.Len()) + } +} + +func TestByteBuffer_GetConsumeDoesNotMutateReturnedSlice(t *testing.T) { + b := &ByteBuffer{Buf: []byte("hello")} + data, ok := b.Get(5, true) + if !ok { + t.Fatal("Get() returned false") + } + if !reflect.DeepEqual(data, []byte("hello")) { + t.Errorf("Get() returned wrong data: %v", data) + } + if b.Len() != 0 { + t.Errorf("after consume, Len() should be 0") + } +} diff --git a/analyzer/utils/lsm_test.go b/analyzer/utils/lsm_test.go new file mode 100644 index 0000000..b7ae12b --- /dev/null +++ b/analyzer/utils/lsm_test.go @@ -0,0 +1,185 @@ +package utils + +import "testing" + +func TestLinearStateMachine_RunPause(t *testing.T) { + callCount := 0 + lsm := NewLinearStateMachine( + func() LSMAction { + callCount++ + return LSMActionPause + }, + ) + cancelled, done := lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true") + } + if done { + t.Error("unexpected done=true") + } + if callCount != 1 { + t.Errorf("callCount = %d, want 1", callCount) + } +} + +func TestLinearStateMachine_RunNext(t *testing.T) { + callOrder := []int{} + lsm := NewLinearStateMachine( + func() LSMAction { callOrder = append(callOrder, 1); return LSMActionNext }, + func() LSMAction { callOrder = append(callOrder, 2); return LSMActionNext }, + func() LSMAction { callOrder = append(callOrder, 3); return LSMActionNext }, + ) + cancelled, done := lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true") + } + if !done { + t.Error("unexpected done=false") + } + if len(callOrder) != 3 { + t.Fatalf("callOrder len = %d, want 3", len(callOrder)) + } + for i, v := range []int{1, 2, 3} { + if callOrder[i] != v { + t.Errorf("callOrder[%d] = %d, want %d", i, callOrder[i], v) + } + } +} + +func TestLinearStateMachine_RunReset(t *testing.T) { + callCount := 0 + lsm := NewLinearStateMachine( + func() LSMAction { + callCount++ + if callCount == 1 { + return LSMActionReset + } + return LSMActionNext + }, + func() LSMAction { callCount++; return LSMActionNext }, + ) + cancelled, done := lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true") + } + if !done { + t.Error("unexpected done=false") + } + if callCount != 3 { + t.Errorf("callCount = %d, want 3 (step0 reset, step0 next, step1 next)", callCount) + } +} + +func TestLinearStateMachine_RunCancel(t *testing.T) { + callCount := 0 + lsm := NewLinearStateMachine( + func() LSMAction { callCount++; return LSMActionNext }, + func() LSMAction { callCount++; return LSMActionCancel }, + func() LSMAction { callCount++; return LSMActionNext }, + ) + cancelled, done := lsm.Run() + if !cancelled { + t.Error("unexpected cancelled=false") + } + if !done { + t.Error("unexpected done=false") + } + if callCount != 2 { + t.Errorf("callCount = %d, want 2 (third step should not execute)", callCount) + } +} + +func TestLinearStateMachine_RunMixed(t *testing.T) { + pauseCount := 0 + lsm := NewLinearStateMachine( + func() LSMAction { return LSMActionNext }, + func() LSMAction { + pauseCount++ + if pauseCount == 1 { + return LSMActionPause + } + return LSMActionNext + }, + func() LSMAction { return LSMActionNext }, + func() LSMAction { return LSMActionNext }, + ) + cancelled, done := lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true") + } + if done { + t.Error("unexpected done=true on first run") + } + cancelled, done = lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true on second run") + } + if !done { + t.Error("unexpected done=false on second run") + } +} + +func TestLinearStateMachine_RunEmpty(t *testing.T) { + lsm := NewLinearStateMachine() + cancelled, done := lsm.Run() + if cancelled { + t.Error("unexpected cancelled=true") + } + if !done { + t.Error("unexpected done=false for empty LSM") + } +} + +func TestLinearStateMachine_AppendSteps(t *testing.T) { + lsm := NewLinearStateMachine( + func() LSMAction { return LSMActionNext }, + ) + lsm.Run() + lsm.AppendSteps( + func() LSMAction { return LSMActionNext }, + ) + _, done := lsm.Run() + if !done { + t.Error("unexpected done=false after AppendSteps") + } +} + +func TestLinearStateMachine_Reset(t *testing.T) { + callCount := 0 + lsm := NewLinearStateMachine( + func() LSMAction { callCount++; return LSMActionCancel }, + ) + lsm.Run() + if !lsm.cancelled { + t.Error("expected cancelled=true after cancel") + } + lsm.Reset() + if lsm.cancelled { + t.Error("expected cancelled=false after Reset") + } + if lsm.index != 0 { + t.Errorf("expected index=0 after Reset, got %d", lsm.index) + } + _, done := lsm.Run() + if !done { + t.Error("expected done=true, step executed again after Reset") + } + if callCount != 2 { + t.Errorf("callCount = %d, want 2 (first run + reset run)", callCount) + } +} + +func TestLSMActionConstants(t *testing.T) { + if LSMActionPause != 0 { + t.Errorf("LSMActionPause = %d, want 0", LSMActionPause) + } + if LSMActionNext != 1 { + t.Errorf("LSMActionNext = %d, want 1", LSMActionNext) + } + if LSMActionReset != 2 { + t.Errorf("LSMActionReset = %d, want 2", LSMActionReset) + } + if LSMActionCancel != 3 { + t.Errorf("LSMActionCancel = %d, want 3", LSMActionCancel) + } +} diff --git a/analyzer/utils/string_test.go b/analyzer/utils/string_test.go new file mode 100644 index 0000000..fe35351 --- /dev/null +++ b/analyzer/utils/string_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + "reflect" + "testing" +) + +func TestByteSlicesToStrings(t *testing.T) { + tests := []struct { + name string + input [][]byte + want []string + }{ + {"nil", nil, []string{}}, + {"empty", [][]byte{}, []string{}}, + {"single", [][]byte{[]byte("hello")}, []string{"hello"}}, + {"multiple", [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, []string{"foo", "bar", "baz"}}, + {"empty element", [][]byte{[]byte("a"), []byte{}, []byte("b")}, []string{"a", "", "b"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ByteSlicesToStrings(tt.input) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ByteSlicesToStrings() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..9780d2a --- /dev/null +++ b/errors_test.go @@ -0,0 +1,45 @@ +package mellaris + +import ( + "errors" + "testing" +) + +func TestConfigError_Error(t *testing.T) { + e := ConfigError{Field: "port", Err: errors.New("must be > 0")} + want := "invalid config: port: must be > 0" + if got := e.Error(); got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestConfigError_Error_NilErr(t *testing.T) { + e := ConfigError{Field: "host", Err: nil} + want := "invalid config: host: %!s()" + if got := e.Error(); got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestConfigError_Unwrap(t *testing.T) { + wrapped := errors.New("inner error") + e := ConfigError{Field: "field", Err: wrapped} + if got := e.Unwrap(); got != wrapped { + t.Errorf("Unwrap() = %v, want %v", got, wrapped) + } +} + +func TestConfigError_Unwrap_Nil(t *testing.T) { + e := ConfigError{Field: "field"} + if got := e.Unwrap(); got != nil { + t.Errorf("Unwrap() = %v, want nil", got) + } +} + +func TestConfigError_ErrorsIs(t *testing.T) { + wrapped := errors.New("inner") + e := ConfigError{Field: "x", Err: wrapped} + if !errors.Is(e, wrapped) { + t.Error("errors.Is should find wrapped error") + } +} diff --git a/modifier/interface_test.go b/modifier/interface_test.go new file mode 100644 index 0000000..c1e1ace --- /dev/null +++ b/modifier/interface_test.go @@ -0,0 +1,22 @@ +package modifier + +import ( + "errors" + "testing" +) + +func TestErrInvalidPacket_Error(t *testing.T) { + e := &ErrInvalidPacket{Err: errors.New("bad checksum")} + want := "invalid packet: bad checksum" + if got := e.Error(); got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestErrInvalidArgs_Error(t *testing.T) { + e := &ErrInvalidArgs{Err: errors.New("missing 'a' arg")} + want := "invalid args: missing 'a' arg" + if got := e.Error(); got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} diff --git a/modifier/udp/dns_test.go b/modifier/udp/dns_test.go new file mode 100644 index 0000000..813208d --- /dev/null +++ b/modifier/udp/dns_test.go @@ -0,0 +1,205 @@ +package udp + +import ( + "testing" + + "git.difuse.io/Difuse/Mellaris/modifier" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte { + t.Helper() + dns := &layers.DNS{ + ID: id, + QR: true, + OpCode: layers.DNSOpCodeQuery, + AA: false, + RD: true, + RA: true, + ResponseCode: layers.DNSResponseCodeNoErr, + QDCount: 1, + Questions: []layers.DNSQuestion{ + { + Name: []byte(name), + Type: qtype, + Class: layers.DNSClassIN, + }, + }, + } + buf := gopacket.NewSerializeBuffer() + err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, dns) + if err != nil { + t.Fatalf("failed to serialize DNS response: %v", err) + } + return buf.Bytes() +} + +func TestDNSModifier_Name(t *testing.T) { + m := &DNSModifier{} + if m.Name() != "dns" { + t.Errorf("Name() = %q, want dns", m.Name()) + } +} + +func TestDNSModifier_New_NoArgs(t *testing.T) { + m := &DNSModifier{} + inst, err := m.New(map[string]interface{}{}) + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + if inst == nil { + t.Fatal("New() returned nil instance") + } +} + +func TestDNSModifier_New_WithA(t *testing.T) { + m := &DNSModifier{} + inst, err := m.New(map[string]interface{}{ + "a": "192.168.1.1", + }) + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + dnsInst, ok := inst.(*dnsModifierInstance) + if !ok { + t.Fatalf("New() returned wrong type: %T", inst) + } + if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" { + t.Errorf("A = %v, want 192.168.1.1", dnsInst.A) + } +} + +func TestDNSModifier_New_WithAAAA(t *testing.T) { + m := &DNSModifier{} + inst, err := m.New(map[string]interface{}{ + "aaaa": "2001:db8::1", + }) + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + dnsInst, ok := inst.(*dnsModifierInstance) + if !ok { + t.Fatalf("New() returned wrong type: %T", inst) + } + if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" { + t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA) + } +} + +func TestDNSModifier_New_InvalidA(t *testing.T) { + m := &DNSModifier{} + _, err := m.New(map[string]interface{}{ + "a": "not-an-ip", + }) + if err == nil { + t.Fatal("New() expected error for invalid A") + } +} + +func TestDNSModifier_New_InvalidAAAA(t *testing.T) { + m := &DNSModifier{} + _, err := m.New(map[string]interface{}{ + "aaaa": "not-an-ip", + }) + if err == nil { + t.Fatal("New() expected error for invalid AAAA") + } +} + +func TestDNSModifierInstance_Process_AType(t *testing.T) { + m := &DNSModifier{} + inst, err := m.New(map[string]interface{}{ + "a": "10.10.10.10", + }) + if err != nil { + t.Fatal(err) + } + + dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com") + modified, err := inst.(*dnsModifierInstance).Process(dnsResp) + if err != nil { + t.Fatalf("Process() error: %v", err) + } + + result := &layers.DNS{} + err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback) + if err != nil { + t.Fatalf("decode result: %v", err) + } + if len(result.Answers) != 1 { + t.Fatalf("expected 1 answer, got %d", len(result.Answers)) + } + if result.Answers[0].Type != layers.DNSTypeA { + t.Errorf("answer type = %d, want A", result.Answers[0].Type) + } + if result.Answers[0].IP.String() != "10.10.10.10" { + t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP) + } +} + +func TestDNSModifierInstance_Process_AAAAType(t *testing.T) { + m := &DNSModifier{} + inst, err := m.New(map[string]interface{}{ + "aaaa": "2001:db8::1", + }) + if err != nil { + t.Fatal(err) + } + + dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com") + modified, err := inst.(*dnsModifierInstance).Process(dnsResp) + if err != nil { + t.Fatalf("Process() error: %v", err) + } + + result := &layers.DNS{} + err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback) + if err != nil { + t.Fatalf("decode result: %v", err) + } + if len(result.Answers) != 1 { + t.Fatalf("expected 1 answer, got %d", len(result.Answers)) + } + if result.Answers[0].Type != layers.DNSTypeAAAA { + t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type) + } +} + +func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) { + inst := &dnsModifierInstance{} + query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com") + query[2] &^= 0x80 // clear QR bit + + _, err := inst.Process(query) + if err == nil { + t.Fatal("expected error for non-response") + } + errPkt, ok := err.(*modifier.ErrInvalidPacket) + if !ok { + t.Fatalf("expected *ErrInvalidPacket, got %T", err) + } + _ = errPkt +} + +func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) { + inst := &dnsModifierInstance{} + dns := &layers.DNS{ + ID: 1, + QR: true, + RD: true, + RA: true, + ResponseCode: layers.DNSResponseCodeNoErr, + } + buf := gopacket.NewSerializeBuffer() + gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns) + + _, err := inst.Process(buf.Bytes()) + if err == nil { + t.Fatal("expected error for empty questions") + } +} diff --git a/ruleset/builtins/cidr_test.go b/ruleset/builtins/cidr_test.go new file mode 100644 index 0000000..9bd8499 --- /dev/null +++ b/ruleset/builtins/cidr_test.go @@ -0,0 +1,98 @@ +package builtins + +import ( + "net" + "testing" +) + +func TestCompileCIDR(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + wantStr string + }{ + {"valid ipv4", "192.168.0.0/24", false, "192.168.0.0/24"}, + {"valid ipv6", "2001:db8::/32", false, "2001:db8::/32"}, + {"valid host ipv4", "10.0.0.1/32", false, "10.0.0.1/32"}, + {"valid host ipv6", "::1/128", false, "::1/128"}, + {"invalid no mask", "192.168.0.0", true, ""}, + {"invalid bad ip", "not-an-ip/24", true, ""}, + {"invalid empty", "", true, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CompileCIDR(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("CompileCIDR(%q) expected error, got nil", tt.input) + } + return + } + if err != nil { + t.Fatalf("CompileCIDR(%q) unexpected error: %v", tt.input, err) + } + if got.String() != tt.wantStr { + t.Errorf("CompileCIDR(%q) = %q, want %q", tt.input, got.String(), tt.wantStr) + } + }) + } +} + +func TestMatchCIDR(t *testing.T) { + cidr := mustCompileCIDR(t, "192.168.0.0/24") + + tests := []struct { + name string + ip string + want bool + }{ + {"inside", "192.168.0.1", true}, + {"boundary low", "192.168.0.0", true}, + {"boundary high", "192.168.0.255", true}, + {"outside", "192.168.1.1", false}, + {"different network", "10.0.0.1", false}, + {"invalid ip", "not-an-ip", false}, + {"empty", "", false}, + {"ipv6 in ipv4", "::1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchCIDR(tt.ip, cidr) + if got != tt.want { + t.Errorf("MatchCIDR(%q, %q) = %v, want %v", tt.ip, cidr, got, tt.want) + } + }) + } +} + +func TestMatchCIDR_IPv6(t *testing.T) { + cidr := mustCompileCIDR(t, "2001:db8::/32") + + inside := "2001:db8::1" + if !MatchCIDR(inside, cidr) { + t.Errorf("MatchCIDR(%q) should be true", inside) + } + + outside := "2001:db9::1" + if MatchCIDR(outside, cidr) { + t.Errorf("MatchCIDR(%q) should be false", outside) + } +} + +func TestMatchCIDR_NullResult(t *testing.T) { + if MatchCIDR("10.0.0.1", &net.IPNet{}) { + t.Error("MatchCIDR with empty IPNet should return false") + } +} + +func mustCompileCIDR(t *testing.T, cidr string) *net.IPNet { + t.Helper() + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatalf("failed to parse CIDR %q: %v", cidr, err) + } + return ipNet +} diff --git a/ruleset/builtins/geo/geo_matcher_test.go b/ruleset/builtins/geo/geo_matcher_test.go new file mode 100644 index 0000000..a4bf580 --- /dev/null +++ b/ruleset/builtins/geo/geo_matcher_test.go @@ -0,0 +1,115 @@ +package geo + +import ( + "testing" + + "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" +) + +type fakeGeoLoader struct { + geoip map[string]*v2geo.GeoIP + geosite map[string]*v2geo.GeoSite +} + +func (l *fakeGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { + return l.geoip, nil +} + +func (l *fakeGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { + return l.geosite, nil +} + +func TestGeoMatcher_MatchGeoIp_Cached(t *testing.T) { + loader := &fakeGeoLoader{ + geoip: map[string]*v2geo.GeoIP{ + "us": { + Cidr: []*v2geo.CIDR{ + {Ip: ipv4(8, 8, 8, 0), Prefix: 24}, + }, + }, + }, + } + g := NewGeoMatcher("", "") + g.geoLoader = loader + + if !g.MatchGeoIp("8.8.8.8", "US") { + t.Error("MatchGeoIp should match 8.8.8.8 in US range") + } + if g.MatchGeoIp("9.9.9.9", "US") { + t.Error("MatchGeoIp should not match 9.9.9.9 in US range") + } +} + +func TestGeoMatcher_MatchGeoIp_EmptyCondition(t *testing.T) { + g := NewGeoMatcher("", "") + if g.MatchGeoIp("1.2.3.4", "") { + t.Error("MatchGeoIp with empty condition should return false") + } +} + +func TestGeoMatcher_MatchGeoIp_InvalidIP(t *testing.T) { + g := NewGeoMatcher("", "") + if g.MatchGeoIp("not-an-ip", "us") { + t.Error("MatchGeoIp with invalid IP should return false") + } +} + +func TestGeoMatcher_MatchGeoIp_MissingCountry(t *testing.T) { + loader := &fakeGeoLoader{ + geoip: map[string]*v2geo.GeoIP{}, + } + g := NewGeoMatcher("", "") + g.geoLoader = loader + + if g.MatchGeoIp("8.8.8.8", "us") { + t.Error("MatchGeoIp for missing country should return false") + } +} + +func TestGeoMatcher_MatchGeoSite(t *testing.T) { + loader := &fakeGeoLoader{ + geosite: map[string]*v2geo.GeoSite{ + "openai": { + Domain: []*v2geo.Domain{ + {Type: v2geo.Domain_Plain, Value: "openai"}, + {Type: v2geo.Domain_Full, Value: "chatgpt.com"}, + }, + }, + }, + } + g := NewGeoMatcher("", "") + g.geoLoader = loader + + if !g.MatchGeoSite("api.openai.com", "openai") { + t.Error("MatchGeoSite should match via plain domain") + } + if !g.MatchGeoSite("chatgpt.com", "openai") { + t.Error("MatchGeoSite should match via full domain") + } + if g.MatchGeoSite("google.com", "openai") { + t.Error("MatchGeoSite should not match unrelated host") + } +} + +func TestGeoMatcher_MatchGeoSite_EmptyCondition(t *testing.T) { + g := NewGeoMatcher("", "") + if g.MatchGeoSite("test.com", "") { + t.Error("MatchGeoSite with empty condition should return false") + } +} + +func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) { + loader := &fakeGeoLoader{ + geosite: map[string]*v2geo.GeoSite{}, + } + g := NewGeoMatcher("", "") + g.geoLoader = loader + + if g.MatchGeoSite("test.com", "nonexistent") { + t.Error("MatchGeoSite for missing site should return false") + } +} + +func ipv4(a, b, c, d byte) []byte { + return []byte{a, b, c, d} +} diff --git a/ruleset/builtins/geo/geoip.dat b/ruleset/builtins/geo/geoip.dat new file mode 100644 index 0000000..a1c108d Binary files /dev/null and b/ruleset/builtins/geo/geoip.dat differ diff --git a/ruleset/builtins/geo/matchers_v2geo_test.go b/ruleset/builtins/geo/matchers_v2geo_test.go new file mode 100644 index 0000000..1c47aca --- /dev/null +++ b/ruleset/builtins/geo/matchers_v2geo_test.go @@ -0,0 +1,324 @@ +package geo + +import ( + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" +) + +func TestParseGeoSiteName(t *testing.T) { + tests := []struct { + input string + wantBase string + wantAttrs []string + }{ + {"google", "google", nil}, + {"google@ads", "google", []string{"ads"}}, + {"google@ads@news", "google", []string{"ads", "news"}}, + {" google ", "google", nil}, + {" google @ ads ", "google", []string{"ads"}}, + {"openai@ ads @ news ", "openai", []string{"ads", "news"}}, + {"@onlyattrs", "", []string{"onlyattrs"}}, + {"", "", nil}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + base, attrs := parseGeoSiteName(tt.input) + if base != tt.wantBase { + t.Errorf("parseGeoSiteName(%q) base = %q, want %q", tt.input, base, tt.wantBase) + } + if len(attrs) != len(tt.wantAttrs) { + t.Fatalf("parseGeoSiteName(%q) attrs len = %d, want %d", tt.input, len(attrs), len(tt.wantAttrs)) + } + for i, attr := range attrs { + if attr != tt.wantAttrs[i] { + t.Errorf("parseGeoSiteName(%q) attrs[%d] = %q, want %q", tt.input, i, attr, tt.wantAttrs[i]) + } + } + }) + } +} + +func TestHostInfo_String(t *testing.T) { + h := HostInfo{ + Name: "example.com", + IPv4: net.ParseIP("1.2.3.4"), + IPv6: net.ParseIP("::1"), + } + want := "example.com|1.2.3.4|::1" + if got := h.String(); got != want { + t.Errorf("HostInfo.String() = %q, want %q", got, want) + } +} + +func TestHostInfo_String_Partial(t *testing.T) { + h := HostInfo{ + Name: "test.com", + IPv4: net.ParseIP("10.0.0.1"), + } + want := "test.com|10.0.0.1|" + if got := h.String(); got != want { + t.Errorf("HostInfo.String() = %q, want %q", got, want) + } +} + +func TestGeoipMatcher_Match(t *testing.T) { + _, n4, _ := net.ParseCIDR("10.0.0.0/8") + _, n4_2, _ := net.ParseCIDR("192.168.0.0/16") + m := &geoipMatcher{ + N4: []*net.IPNet{n4, n4_2}, + } + + tests := []struct { + name string + host HostInfo + want bool + }{ + {"ipv4 match", HostInfo{IPv4: net.ParseIP("10.1.2.3")}, true}, + {"ipv4 no match", HostInfo{IPv4: net.ParseIP("172.16.0.1")}, false}, + {"ipv4 match second net", HostInfo{IPv4: net.ParseIP("192.168.1.1")}, true}, + {"no ip", HostInfo{}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := m.Match(tt.host); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGeoipMatcher_Match_Inverse(t *testing.T) { + _, n4, _ := net.ParseCIDR("10.0.0.0/8") + m := &geoipMatcher{ + N4: []*net.IPNet{n4}, + Inverse: true, + } + + if m.Match(HostInfo{IPv4: net.ParseIP("10.1.2.3")}) { + t.Error("Inverse: inside range should return false") + } + if !m.Match(HostInfo{IPv4: net.ParseIP("172.16.0.1")}) { + t.Error("Inverse: outside range should return true") + } + if !m.Match(HostInfo{}) { + t.Error("Inverse: no IP should return true") + } +} + +func TestGeoipMatcher_Match_IPv6(t *testing.T) { + _, n6, _ := net.ParseCIDR("2001:db8::/32") + m := &geoipMatcher{ + N6: []*net.IPNet{n6}, + } + + if !m.Match(HostInfo{IPv6: net.ParseIP("2001:db8::1")}) { + t.Error("IPv6 match failed") + } + if m.Match(HostInfo{IPv6: net.ParseIP("2001:db9::1")}) { + t.Error("IPv6 should not match") + } +} + +func TestGeositeMatcher_matchDomain_Plain(t *testing.T) { + m := &geositeMatcher{} + d := geositeDomain{ + Type: geositeDomainPlain, + Value: "openai", + } + if !m.matchDomain(d, HostInfo{Name: "api.openai.com"}) { + t.Error("plain domain should match via substring") + } + if m.matchDomain(d, HostInfo{Name: "google.com"}) { + t.Error("plain domain should not match unrelated host") + } +} + +func TestGeositeMatcher_matchDomain_Full(t *testing.T) { + m := &geositeMatcher{} + d := geositeDomain{ + Type: geositeDomainFull, + Value: "example.com", + } + if !m.matchDomain(d, HostInfo{Name: "example.com"}) { + t.Error("full domain should match exact") + } + if m.matchDomain(d, HostInfo{Name: "www.example.com"}) { + t.Error("full domain should not match subdomain") + } +} + +func TestGeositeMatcher_matchDomain_Root(t *testing.T) { + m := &geositeMatcher{} + d := geositeDomain{ + Type: geositeDomainRoot, + Value: "example.com", + } + if !m.matchDomain(d, HostInfo{Name: "example.com"}) { + t.Error("root domain should match exact") + } + if !m.matchDomain(d, HostInfo{Name: "www.example.com"}) { + t.Error("root domain should match subdomain") + } + if m.matchDomain(d, HostInfo{Name: "www.example.com.au"}) { + t.Error("root domain should not match unrelated suffix") + } +} + +func TestGeositeMatcher_matchDomain_Attrs(t *testing.T) { + m := &geositeMatcher{Attrs: []string{"ads"}} + d := geositeDomain{ + Type: geositeDomainPlain, + Value: "google", + Attrs: map[string]bool{"ads": true}, + } + if !m.matchDomain(d, HostInfo{Name: "google.com"}) { + t.Error("should match when domain has required attr") + } + + dNoAttrs := geositeDomain{ + Type: geositeDomainPlain, + Value: "google", + Attrs: map[string]bool{}, + } + if m.matchDomain(dNoAttrs, HostInfo{Name: "google.com"}) { + t.Error("should not match when domain lacks required attr") + } + + dOtherAttrs := geositeDomain{ + Type: geositeDomainPlain, + Value: "google", + Attrs: map[string]bool{"news": true}, + } + if m.matchDomain(dOtherAttrs, HostInfo{Name: "google.com"}) { + t.Error("should not match when domain has wrong attr") + } +} + +func TestGeositeMatcher_Match(t *testing.T) { + m := &geositeMatcher{ + Domains: []geositeDomain{ + {Type: geositeDomainFull, Value: "exact.com"}, + {Type: geositeDomainPlain, Value: "partial"}, + }, + } + if !m.Match(HostInfo{Name: "exact.com"}) { + t.Error("should match full domain") + } + if !m.Match(HostInfo{Name: "www.partial.net"}) { + t.Error("should match partial domain") + } + if m.Match(HostInfo{Name: "other.net"}) { + t.Error("should not match unrelated host") + } +} + +func TestDomainAttributeToMap(t *testing.T) { + attrs := []*v2geo.Domain_Attribute{ + {Key: "ads"}, + {Key: "news"}, + } + got := domainAttributeToMap(attrs) + if len(got) != 2 || !got["ads"] || !got["news"] { + t.Errorf("domainAttributeToMap = %v, want {ads:true, news:true}", got) + } + + got2 := domainAttributeToMap(nil) + if len(got2) != 0 { + t.Errorf("domainAttributeToMap(nil) = %v, want empty map", got2) + } +} + +func TestNewGeoIPMatcher(t *testing.T) { + list := &v2geo.GeoIP{ + Cidr: []*v2geo.CIDR{ + {Ip: net.IPv4(10, 0, 0, 0).To4(), Prefix: 8}, + {Ip: net.IPv4(192, 168, 0, 0).To4(), Prefix: 16}, + }, + InverseMatch: false, + } + m, err := newGeoIPMatcher(list) + if err != nil { + t.Fatalf("newGeoIPMatcher error: %v", err) + } + if len(m.N4) != 2 { + t.Errorf("expected 2 IPv4 nets, got %d", len(m.N4)) + } + if m.Inverse { + t.Error("Inverse should be false") + } + // Verify sorted order: 10.0.0.0/8 < 192.168.0.0/16 + if m.N4[0].IP.String() != "10.0.0.0" { + t.Errorf("N4[0] = %s, want 10.0.0.0", m.N4[0].IP) + } + if m.N4[1].IP.String() != "192.168.0.0" { + t.Errorf("N4[1] = %s, want 192.168.0.0", m.N4[1].IP) + } +} + +func TestNewGeoIPMatcher_IPv6(t *testing.T) { + list := &v2geo.GeoIP{ + Cidr: []*v2geo.CIDR{ + {Ip: net.ParseIP("2001:db8::"), Prefix: 32}, + }, + } + m, err := newGeoIPMatcher(list) + if err != nil { + t.Fatalf("newGeoIPMatcher error: %v", err) + } + if len(m.N6) != 1 { + t.Errorf("expected 1 IPv6 net, got %d", len(m.N6)) + } +} + +func TestNewGeoIPMatcher_InvalidIPLength(t *testing.T) { + list := &v2geo.GeoIP{ + Cidr: []*v2geo.CIDR{ + {Ip: []byte{1, 2, 3}, Prefix: 24}, + }, + } + _, err := newGeoIPMatcher(list) + if err == nil { + t.Error("expected error for invalid IP length") + } +} + +func TestNewGeositeMatcher(t *testing.T) { + list := &v2geo.GeoSite{ + Domain: []*v2geo.Domain{ + {Type: v2geo.Domain_Plain, Value: "google"}, + {Type: v2geo.Domain_Full, Value: "exact.com"}, + }, + } + m, err := newGeositeMatcher(list, nil) + if err != nil { + t.Fatalf("newGeositeMatcher error: %v", err) + } + if len(m.Domains) != 2 { + t.Errorf("expected 2 domains, got %d", len(m.Domains)) + } +} + +func TestNewGeositeMatcher_WithAttrs(t *testing.T) { + list := &v2geo.GeoSite{ + Domain: []*v2geo.Domain{ + { + Type: v2geo.Domain_RootDomain, + Value: "google.com", + Attribute: []*v2geo.Domain_Attribute{ + {Key: "ads"}, + }, + }, + }, + } + m, err := newGeositeMatcher(list, []string{"ads"}) + if err != nil { + t.Fatalf("newGeositeMatcher error: %v", err) + } + if !m.Match(HostInfo{Name: "www.google.com"}) { + t.Error("should match with root domain and attr") + } +} diff --git a/ruleset/expr.go b/ruleset/expr.go index 276d7aa..d4bceaa 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -26,11 +26,14 @@ import ( // ExprRule is the external representation of an expression rule. type ExprRule struct { - Name string `yaml:"name"` - Action string `yaml:"action"` - Log bool `yaml:"log"` - Modifier ModifierEntry `yaml:"modifier"` - Expr string `yaml:"expr"` + Name string `yaml:"name"` + Action string `yaml:"action"` + Log bool `yaml:"log"` + Modifier ModifierEntry `yaml:"modifier"` + Expr string `yaml:"expr"` + StartTime string `yaml:"start_time"` + StopTime string `yaml:"stop_time"` + Weekdays []string `yaml:"weekdays"` } type ModifierEntry struct { @@ -56,6 +59,10 @@ type compiledExprRule struct { ModInstance modifier.Instance Program *vm.Program GeoSiteConditions []string + StartTimeSecs int // seconds since midnight, -1 if unset + StopTimeSecs int // seconds since midnight, -1 if unset + Weekdays []time.Weekday + WeekdaysNegated bool } var _ Ruleset = (*exprRuleset)(nil) @@ -73,10 +80,13 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { func (r *exprRuleset) Match(info StreamInfo) MatchResult { env := streamInfoToExprEnv(info) + now := time.Now() for _, rule := range r.Rules { + if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { + continue + } v, err := vm.Run(rule.Program, env) if err != nil { - // Log the error and continue to the next rule. r.Logger.MatchError(info, rule.Name, err) continue } @@ -163,12 +173,34 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier depAnMap[name] = a } } + startSecs := -1 + if rule.StartTime != "" { + startSecs, err = parseTimeOfDay(rule.StartTime) + if err != nil { + return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err) + } + } + stopSecs := -1 + if rule.StopTime != "" { + stopSecs, err = parseTimeOfDay(rule.StopTime) + if err != nil { + return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err) + } + } + weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays) + if err != nil { + return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err) + } cr := compiledExprRule{ Name: rule.Name, Action: action, Log: rule.Log, Program: program, GeoSiteConditions: extractGeoSiteConditions(rule.Expr), + StartTimeSecs: startSecs, + StopTimeSecs: stopSecs, + Weekdays: weekdays, + WeekdaysNegated: weekdaysNegated, } if action != nil && *action == ActionModify { mod, ok := fullModMap[rule.Modifier.Name] @@ -391,6 +423,87 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc }, geoMatcher } +func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool { + if startSecs >= 0 || stopSecs >= 0 { + currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second() + if startSecs >= 0 && stopSecs >= 0 { + if startSecs <= stopSecs { + if currentSecs < startSecs || currentSecs > stopSecs { + return false + } + } else { + if currentSecs < startSecs && currentSecs > stopSecs { + return false + } + } + } else if startSecs >= 0 { + if currentSecs < startSecs { + return false + } + } else if currentSecs > stopSecs { + return false + } + } + if len(weekdays) > 0 { + current := now.Weekday() + found := false + for _, d := range weekdays { + if current == d { + found = true + break + } + } + if negated == found { + return false + } + } + return true +} + +func parseTimeOfDay(s string) (int, error) { + t, err := time.Parse("15:04:05", s) + if err != nil { + return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s) + } + return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil +} + +func parseWeekdays(days []string) ([]time.Weekday, bool, error) { + if len(days) == 0 { + return nil, false, nil + } + negated := false + parsed := make([]time.Weekday, 0, len(days)) + for i, d := range days { + d = strings.TrimSpace(d) + if i == 0 && strings.HasPrefix(d, "!") { + negated = true + d = strings.TrimSpace(strings.TrimPrefix(d, "!")) + } + var wd time.Weekday + switch strings.ToLower(d) { + case "sun", "sunday": + wd = time.Sunday + case "mon", "monday": + wd = time.Monday + case "tue", "tues", "tuesday": + wd = time.Tuesday + case "wed", "wednesday": + wd = time.Wednesday + case "thu", "thur", "thurs", "thursday": + wd = time.Thursday + case "fri", "friday": + wd = time.Friday + case "sat", "saturday": + wd = time.Saturday + default: + return nil, false, fmt.Errorf("invalid weekday %q", d) + } + parsed = append(parsed, wd) + } + return parsed, negated, nil +} + const rulesetLogMetaKey = "_ruleset" func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo { diff --git a/ruleset/interface_test.go b/ruleset/interface_test.go new file mode 100644 index 0000000..f24b780 --- /dev/null +++ b/ruleset/interface_test.go @@ -0,0 +1,121 @@ +package ruleset + +import ( + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/analyzer" +) + +func TestAction_String(t *testing.T) { + tests := []struct { + action Action + want string + }{ + {ActionMaybe, "maybe"}, + {ActionAllow, "allow"}, + {ActionBlock, "block"}, + {ActionDrop, "drop"}, + {ActionModify, "modify"}, + {Action(99), "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.action.String(); got != tt.want { + t.Errorf("Action.String() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestProtocol_String(t *testing.T) { + tests := []struct { + protocol Protocol + want string + }{ + {ProtocolTCP, "tcp"}, + {ProtocolUDP, "udp"}, + {Protocol(99), "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.protocol.String(); got != tt.want { + t.Errorf("Protocol.String() = %q, want %q", got, tt.protocol) + } + }) + } +} + +func TestProtocol_Constants(t *testing.T) { + if ProtocolTCP != 0 { + t.Errorf("ProtocolTCP = %d, want 0", ProtocolTCP) + } + if ProtocolUDP != 1 { + t.Errorf("ProtocolUDP = %d, want 1", ProtocolUDP) + } +} + +func TestAction_Constants(t *testing.T) { + if ActionMaybe != 0 { + t.Errorf("ActionMaybe = %d, want 0", ActionMaybe) + } + if ActionAllow != 1 { + t.Errorf("ActionAllow = %d, want 1", ActionAllow) + } + if ActionBlock != 2 { + t.Errorf("ActionBlock = %d, want 2", ActionBlock) + } +} + +func TestStreamInfo_SrcString(t *testing.T) { + info := StreamInfo{ + SrcIP: net.ParseIP("192.168.1.1"), + SrcPort: 8080, + } + want := "192.168.1.1:8080" + if got := info.SrcString(); got != want { + t.Errorf("SrcString() = %q, want %q", got, want) + } +} + +func TestStreamInfo_DstString(t *testing.T) { + info := StreamInfo{ + DstIP: net.ParseIP("10.0.0.1"), + DstPort: 443, + } + want := "10.0.0.1:443" + if got := info.DstString(); got != want { + t.Errorf("DstString() = %q, want %q", got, want) + } +} + +func TestStreamInfo_SrcString_IPv6(t *testing.T) { + info := StreamInfo{ + SrcIP: net.ParseIP("::1"), + SrcPort: 53, + } + want := "[::1]:53" + if got := info.SrcString(); got != want { + t.Errorf("SrcString() = %q, want %q", got, want) + } +} + +func TestMatchResult_ZeroValue(t *testing.T) { + var mr MatchResult + if mr.Action != ActionMaybe { + t.Errorf("zero MatchResult.Action = %v, want ActionMaybe (0)", mr.Action) + } +} + +func TestStreamInfo_PropsInitialization(t *testing.T) { + info := StreamInfo{ + Props: analyzer.CombinedPropMap{ + "tls": analyzer.PropMap{"sni": "example.com"}, + }, + } + if info.Props.Get("tls", "sni") != "example.com" { + t.Error("StreamInfo.Props not properly initialized") + } +}