test: improve coverage across package
This commit is contained in:
23
README.md
23
README.md
@@ -41,6 +41,29 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run all tests:
|
||||
|
||||
```bash
|
||||
go test ./...
|
||||
```
|
||||
|
||||
Run tests with coverage:
|
||||
|
||||
```bash
|
||||
go test -cover ./... # per-package summary
|
||||
go test -coverprofile=coverage.out ./... # full profile
|
||||
go tool cover -html=coverage.out -o coverage.html # HTML report
|
||||
```
|
||||
|
||||
Run tests for a specific package:
|
||||
|
||||
```bash
|
||||
go test -v ./analyzer/tcp/
|
||||
go test -v -run TestSSH ./analyzer/tcp/
|
||||
```
|
||||
|
||||
Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW
|
||||
|
||||
## LICENSE
|
||||
|
||||
124
analyzer/interface_test.go
Normal file
124
analyzer/interface_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package analyzer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPropMap_Get(t *testing.T) {
|
||||
m := PropMap{
|
||||
"a": "value-a",
|
||||
"nested": PropMap{
|
||||
"b": "value-b",
|
||||
"deep": PropMap{
|
||||
"c": "value-c",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want interface{}
|
||||
}{
|
||||
{"simple key", "a", "value-a"},
|
||||
{"nested key", "nested.b", "value-b"},
|
||||
{"deeply nested", "nested.deep.c", "value-c"},
|
||||
{"non-existent top", "x", nil},
|
||||
{"non-existent nested", "nested.x", nil},
|
||||
{"empty key", "", nil},
|
||||
{"partial path", "nested", PropMap{"b": "value-b", "deep": PropMap{"c": "value-c"}}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := m.Get(tt.key)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("PropMap.Get(%q) = %v, want %v", tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropMap_Get_NilMap(t *testing.T) {
|
||||
var m PropMap
|
||||
if got := m.Get("any"); got != nil {
|
||||
t.Errorf("nil PropMap.Get() = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropMap_Get_EmptyMap(t *testing.T) {
|
||||
m := PropMap{}
|
||||
if got := m.Get("any"); got != nil {
|
||||
t.Errorf("empty PropMap.Get() = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropMap_Get_IntermediateNonMap(t *testing.T) {
|
||||
m := PropMap{
|
||||
"a": "hello",
|
||||
}
|
||||
if got := m.Get("a.b.c"); got != nil {
|
||||
t.Errorf("PropMap.Get() through non-map = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombinedPropMap_Get(t *testing.T) {
|
||||
cm := CombinedPropMap{
|
||||
"tls": PropMap{
|
||||
"req": PropMap{
|
||||
"sni": "example.com",
|
||||
},
|
||||
},
|
||||
"http": PropMap{
|
||||
"resp": PropMap{
|
||||
"status": 200,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
analyzer string
|
||||
key string
|
||||
want interface{}
|
||||
}{
|
||||
{"tls sni", "tls", "req.sni", "example.com"},
|
||||
{"http status", "http", "resp.status", 200},
|
||||
{"unknown analyzer", "dns", "any", nil},
|
||||
{"valid analyzer bad key", "tls", "req.nonexistent", nil},
|
||||
{"empty analyzer", "", "key", nil},
|
||||
{"empty key", "tls", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := cm.Get(tt.analyzer, tt.key)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CombinedPropMap.Get(%q, %q) = %v, want %v", tt.analyzer, tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombinedPropMap_Get_Nil(t *testing.T) {
|
||||
var cm CombinedPropMap
|
||||
if got := cm.Get("any", "key"); got != nil {
|
||||
t.Errorf("nil CombinedPropMap.Get() = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropUpdateType_Values(t *testing.T) {
|
||||
if PropUpdateNone != 0 {
|
||||
t.Errorf("PropUpdateNone = %d, want 0", PropUpdateNone)
|
||||
}
|
||||
if PropUpdateMerge != 1 {
|
||||
t.Errorf("PropUpdateMerge = %d, want 1", PropUpdateMerge)
|
||||
}
|
||||
if PropUpdateReplace != 2 {
|
||||
t.Errorf("PropUpdateReplace = %d, want 2", PropUpdateReplace)
|
||||
}
|
||||
if PropUpdateDelete != 3 {
|
||||
t.Errorf("PropUpdateDelete = %d, want 3", PropUpdateDelete)
|
||||
}
|
||||
}
|
||||
319
analyzer/internal/tls_test.go
Normal file
319
analyzer/internal/tls_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
|
||||
)
|
||||
|
||||
func buildClientHelloMsg(t *testing.T) *utils.ByteBuffer {
|
||||
t.Helper()
|
||||
// Bytes taken from the standard TLS test vector, starting after the record
|
||||
// header (5 bytes) and handshake header (4 bytes).
|
||||
body := []byte{
|
||||
0x03, 0x03, // version TLS 1.2
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length = 0
|
||||
0x00, 0x20, // cipher suites length = 32 bytes = 16 suites
|
||||
0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30,
|
||||
0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09,
|
||||
0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d,
|
||||
0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, // ciphers
|
||||
0x01, // compression methods length = 1
|
||||
0x00, // compression method = null
|
||||
0x00, 0x58, // extensions length = 88
|
||||
// extension: server_name
|
||||
0x00, 0x00, // type = server_name
|
||||
0x00, 0x18, // length = 24
|
||||
0x00, 0x16, // server name list length = 22
|
||||
0x00, // name type = hostname
|
||||
0x00, 0x13, // name length = 19
|
||||
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'u', 'l', 'f', 'h', 'e', 'i', 'm', '.', 'n', 'e', 't',
|
||||
// extension: status_request
|
||||
0x00, 0x05, // type = status_request
|
||||
0x00, 0x05, // length = 5
|
||||
0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
// extension: supported_groups
|
||||
0x00, 0x0a, // type = supported_groups
|
||||
0x00, 0x0a, // length = 10
|
||||
0x00, 0x08, // list length = 8
|
||||
0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19,
|
||||
// extension: ec_point_formats
|
||||
0x00, 0x0b, // type = ec_point_formats
|
||||
0x00, 0x02, // length = 2
|
||||
0x01, 0x00,
|
||||
// extension: signature_algorithms
|
||||
0x00, 0x0d, // type = signature_algorithms
|
||||
0x00, 0x12, // length = 18
|
||||
0x00, 0x10, // list length = 16
|
||||
0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03,
|
||||
0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03,
|
||||
// extension: renegotiation_info
|
||||
0xff, 0x01, // type = renegotiation_info
|
||||
0x00, 0x01, // length = 1
|
||||
0x00,
|
||||
// extension: extended_master_secret (empty)
|
||||
0x00, 0x17, // type = extended_master_secret
|
||||
0x00, 0x00, // length = 0
|
||||
// extension: session_ticket (empty)
|
||||
0x00, 0x23, // type = session_ticket
|
||||
0x00, 0x00, // length = 0
|
||||
}
|
||||
return &utils.ByteBuffer{Buf: body}
|
||||
}
|
||||
|
||||
func TestParseTLSClientHelloMsgData(t *testing.T) {
|
||||
chBuf := buildClientHelloMsg(t)
|
||||
m := ParseTLSClientHelloMsgData(chBuf)
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||
}
|
||||
|
||||
wantVersion := uint16(0x0303)
|
||||
if v, ok := m["version"].(uint16); !ok || v != wantVersion {
|
||||
t.Errorf("version = %v, want %v", m["version"], wantVersion)
|
||||
}
|
||||
|
||||
wantCiphers := []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10}
|
||||
if c, ok := m["ciphers"].([]uint16); ok {
|
||||
if !reflect.DeepEqual(c, wantCiphers) {
|
||||
t.Errorf("ciphers = %v, want %v", c, wantCiphers)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("ciphers missing or wrong type: %T", m["ciphers"])
|
||||
}
|
||||
|
||||
if sni, ok := m["sni"].(string); !ok || sni != "example.ulfheim.net" {
|
||||
t.Errorf("sni = %q, want %q", m["sni"], "example.ulfheim.net")
|
||||
}
|
||||
|
||||
if _, ok := m["compression"]; !ok {
|
||||
t.Error("compression key missing")
|
||||
}
|
||||
|
||||
if _, ok := m["session"]; !ok {
|
||||
t.Error("session key missing")
|
||||
}
|
||||
|
||||
if _, ok := m["random"]; !ok {
|
||||
t.Error("random key missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSServerHelloMsgData(t *testing.T) {
|
||||
// ServerHello message body (after record header + handshake header)
|
||||
body := []byte{
|
||||
0x03, 0x03, // version TLS 1.2
|
||||
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
|
||||
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
|
||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
|
||||
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, // random
|
||||
0x00, // session ID length = 0
|
||||
0xc0, 0x13, // cipher suite
|
||||
0x00, // compression method
|
||||
0x00, 0x05, // extensions length = 5
|
||||
// extension: renegotiation_info
|
||||
0xff, 0x01, // type
|
||||
0x00, 0x01, // length = 1
|
||||
0x00,
|
||||
}
|
||||
|
||||
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSServerHelloMsgData returned nil")
|
||||
}
|
||||
|
||||
wantCipher := uint16(0xc013)
|
||||
if c, ok := m["cipher"].(uint16); !ok || c != wantCipher {
|
||||
t.Errorf("cipher = %v, want %v", m["cipher"], wantCipher)
|
||||
}
|
||||
|
||||
wantCompression := uint8(0)
|
||||
if c, ok := m["compression"].(uint8); !ok || c != wantCompression {
|
||||
t.Errorf("compression = %v, want %v", m["compression"], wantCompression)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSServerHelloMsgData_NoExtensions(t *testing.T) {
|
||||
// ServerHello message without extensions
|
||||
body := []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x00, 0xff, // cipher suite
|
||||
0x00, // compression method
|
||||
}
|
||||
|
||||
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSServerHelloMsgData returned nil for no-extensions case")
|
||||
}
|
||||
if _, ok := m["cipher"]; !ok {
|
||||
t.Error("cipher key missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSClientHelloMsgData_Truncated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
}{
|
||||
{"too short for session id", []byte{0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}},
|
||||
{"odd cipher suites length", []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x00, 0x03, // odd cipher suites length
|
||||
}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: tt.buf})
|
||||
if m != nil {
|
||||
t.Error("expected nil for truncated input")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSClientHelloMsgData_ECH(t *testing.T) {
|
||||
// ClientHello with ECH extension
|
||||
body := []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x00, 0x02, // cipher suites length = 2
|
||||
0x13, 0x01, // TLS_AES_128_GCM_SHA256
|
||||
0x01, // compression methods length = 1
|
||||
0x00, // null compression
|
||||
0x00, 0x07, // extensions length = 7
|
||||
// ECH extension
|
||||
0xfe, 0x0d, // type = encrypted_client_hello
|
||||
0x00, 0x03, // length = 3
|
||||
0x01, 0x02, 0x03, // some ECH data
|
||||
}
|
||||
|
||||
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||
}
|
||||
ech, ok := m["ech"].(bool)
|
||||
if !ok || !ech {
|
||||
t.Errorf("ech = %v, want true", m["ech"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSClientHelloMsgData_ALPN(t *testing.T) {
|
||||
// ClientHello with ALPN extension
|
||||
body := []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x00, 0x02, // cipher suites length = 2
|
||||
0x13, 0x01, // cipher suite
|
||||
0x01, // compression methods length = 1
|
||||
0x00, // compression method
|
||||
0x00, 0x12, // extensions length = 18
|
||||
// ALPN extension
|
||||
0x00, 0x10, // type = ALPN
|
||||
0x00, 0x0e, // length = 14
|
||||
0x00, 0x0c, // list length = 12
|
||||
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
|
||||
0x02, 'h', '2',
|
||||
}
|
||||
|
||||
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||
}
|
||||
alpn, ok := m["alpn"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("alpn missing or wrong type: %T", m["alpn"])
|
||||
}
|
||||
if len(alpn) != 2 || alpn[0] != "http/1.1" || alpn[1] != "h2" {
|
||||
t.Errorf("alpn = %v, want [http/1.1 h2]", alpn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSClientHelloMsgData_SupportedVersionsClient(t *testing.T) {
|
||||
// ClientHello with supported_versions extension (client format - list)
|
||||
body := []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x00, 0x02, // cipher suites length = 2
|
||||
0x13, 0x01, // cipher suite
|
||||
0x01, // compression methods length = 1
|
||||
0x00, // compression method
|
||||
0x00, 0x0b, // extensions length = 11
|
||||
// supported_versions (client list format)
|
||||
0x00, 0x2b, // type = supported_versions
|
||||
0x00, 0x07, // length = 7
|
||||
0x06, // list length = 6
|
||||
0x03, 0x04, // TLS 1.3
|
||||
0x03, 0x03, // TLS 1.2
|
||||
0x03, 0x01, // TLS 1.0
|
||||
}
|
||||
|
||||
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||
}
|
||||
versions, ok := m["supported_versions"].([]uint16)
|
||||
if !ok {
|
||||
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
|
||||
}
|
||||
want := []uint16{0x0304, 0x0303, 0x0301}
|
||||
if !reflect.DeepEqual(versions, want) {
|
||||
t.Errorf("supported_versions = %v, want %v", versions, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTLSServerHelloMsgData_SupportedVersionsServer(t *testing.T) {
|
||||
// ServerHello with supported_versions extension (server format - single value)
|
||||
body := []byte{
|
||||
0x03, 0x03, // version
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||
0x00, // session ID length
|
||||
0x13, 0x01, // cipher suite
|
||||
0x00, // compression method
|
||||
0x00, 0x06, // extensions length = 6
|
||||
// supported_versions (server format - single value)
|
||||
0x00, 0x2b, // type
|
||||
0x00, 0x02, // length = 2
|
||||
0x03, 0x04, // TLS 1.3
|
||||
}
|
||||
|
||||
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||
if m == nil {
|
||||
t.Fatal("ParseTLSServerHelloMsgData returned nil")
|
||||
}
|
||||
v, ok := m["supported_versions"].(uint16)
|
||||
if !ok {
|
||||
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
|
||||
}
|
||||
if v != 0x0304 {
|
||||
t.Errorf("supported_versions = 0x%04x, want 0x0304", v)
|
||||
}
|
||||
}
|
||||
256
analyzer/tcp/fet_test.go
Normal file
256
analyzer/tcp/fet_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPopCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
input byte
|
||||
want int
|
||||
}{
|
||||
{0x00, 0},
|
||||
{0x01, 1},
|
||||
{0x02, 1},
|
||||
{0x03, 2},
|
||||
{0x07, 3},
|
||||
{0x0f, 4},
|
||||
{0xff, 8},
|
||||
{0x55, 4}, // 01010101
|
||||
{0xaa, 4}, // 10101010
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := popCount(tt.input); got != tt.want {
|
||||
t.Errorf("popCount(0x%02x) = %d, want %d", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAveragePopCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want float32
|
||||
}{
|
||||
{"empty", []byte{}, 0},
|
||||
{"all zeros", []byte{0x00, 0x00, 0x00}, 0},
|
||||
{"all ones", []byte{0xff, 0xff}, 8.0},
|
||||
{"mixed", []byte{0x00, 0xff}, 4.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := averagePopCount(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("averagePopCount() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsFirstSixPrintable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want bool
|
||||
}{
|
||||
{"too short", []byte("abc"), false},
|
||||
{"all printable", []byte("abcdef"), true},
|
||||
{"non-printable at pos 0", []byte{0x00, 'a', 'b', 'c', 'd', 'e'}, false},
|
||||
{"non-printable at pos 5", []byte{'a', 'b', 'c', 'd', 'e', 0x1f}, false},
|
||||
{"exactly 6 printable", []byte("123456"), true},
|
||||
{"spaces", []byte(" "), true},
|
||||
{"non-printable middle", []byte{'a', 'b', 0x01, 'd', 'e', 'f'}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isFirstSixPrintable(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("isFirstSixPrintable(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintablePercentage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want float32
|
||||
}{
|
||||
{"empty", []byte{}, 0},
|
||||
{"all printable", []byte("hello"), 1.0},
|
||||
{"none printable", []byte{0x00, 0x01, 0x02, 0x03}, 0},
|
||||
{"half printable", []byte{'a', 0x00, 'b', 0x00}, 0.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := printablePercentage(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("printablePercentage() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContiguousPrintable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want int
|
||||
}{
|
||||
{"empty", []byte{}, 0},
|
||||
{"all printable", []byte("hello world"), 11},
|
||||
{"none printable", []byte{0x00, 0x01, 0x02}, 0},
|
||||
{"start printable", []byte{'a', 'b', 'c', 0x00, 'd', 'e', 'f'}, 3},
|
||||
{"end printable", []byte{0x00, 'a', 'b', 'c', 'd', 'e', 'f'}, 6},
|
||||
{"middle printable", []byte{0x00, 'a', 'b', 'c', 0x00}, 3},
|
||||
{"two segments", []byte{'a', 'b', 0x00, 'c', 'd', 'e'}, 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := contiguousPrintable(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("contiguousPrintable() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTLSorHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want bool
|
||||
}{
|
||||
{"too short", []byte("AB"), false},
|
||||
// TLS ClientHello: 0x16 0x03 0x01
|
||||
{"tls 1.0", []byte{0x16, 0x03, 0x01}, true},
|
||||
// TLS 0x17 is application data record
|
||||
{"tls app data", []byte{0x17, 0x03, 0x03}, true},
|
||||
{"tls max content type", []byte{0x17, 0x03, 0x09}, true},
|
||||
{"bad tls content type", []byte{0x15, 0x03, 0x01}, false},
|
||||
{"bad tls version", []byte{0x16, 0x04, 0x01}, false},
|
||||
{"bad tls length", []byte{0x16, 0x03, 0x0a}, false},
|
||||
// HTTP methods
|
||||
{"GET", []byte("GET / HTTP/1.1..."), true},
|
||||
{"HEAD", []byte("HEAD /index.html..."), true},
|
||||
{"POST", []byte("POST /api..."), true},
|
||||
{"PUT", []byte("PUT /data..."), true},
|
||||
{"DELETE", []byte("DELETE /..."), true},
|
||||
{"CONNECT", []byte("CONNECT proxy..."), true},
|
||||
{"OPTIONS", []byte("OPTIONS *..."), true},
|
||||
{"TRACE", []byte("TRACE /..."), true},
|
||||
{"PATCH", []byte("PATCH /..."), true},
|
||||
{"random data", []byte{0x00, 0x01, 0x02, 0x03}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isTLSorHTTP(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("isTLSorHTTP(%v) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrintable(t *testing.T) {
|
||||
if !isPrintable('a') {
|
||||
t.Error("'a' should be printable")
|
||||
}
|
||||
if isPrintable(0x00) {
|
||||
t.Error("0x00 should not be printable")
|
||||
}
|
||||
if !isPrintable(0x20) {
|
||||
t.Error("0x20 (space) should be printable")
|
||||
}
|
||||
if !isPrintable(0x7e) {
|
||||
t.Error("0x7e (~) should be printable")
|
||||
}
|
||||
if isPrintable(0x7f) {
|
||||
t.Error("0x7f (DEL) should not be printable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETStream_Feed(t *testing.T) {
|
||||
s := newFETStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if !done {
|
||||
t.Error("FET should be done after first packet")
|
||||
}
|
||||
m := u.M
|
||||
if _, ok := m["ex1"]; !ok {
|
||||
t.Error("ex1 missing")
|
||||
}
|
||||
if _, ok := m["ex2"]; !ok {
|
||||
t.Error("ex2 missing")
|
||||
}
|
||||
if _, ok := m["ex3"]; !ok {
|
||||
t.Error("ex3 missing")
|
||||
}
|
||||
if _, ok := m["ex4"]; !ok {
|
||||
t.Error("ex4 missing")
|
||||
}
|
||||
if _, ok := m["ex5"]; !ok {
|
||||
t.Error("ex5 missing")
|
||||
}
|
||||
if yes, ok := m["yes"].(bool); ok && yes {
|
||||
t.Error("HTTP should be exempt (yes=false)")
|
||||
}
|
||||
if u.Type != 2 {
|
||||
t.Errorf("prop update type = %d, want PropUpdateReplace", u.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETStream_Feed_EncryptedLike(t *testing.T) {
|
||||
s := newFETStream(nil)
|
||||
data := make([]byte, 100)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
u, done := s.Feed(false, false, false, 0, data)
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if !done {
|
||||
t.Error("should be done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETStream_Feed_Skip(t *testing.T) {
|
||||
s := newFETStream(nil)
|
||||
_, done := s.Feed(false, false, false, 5, []byte("data"))
|
||||
if !done {
|
||||
t.Error("skip != 0 should return done=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETStream_Feed_Empty(t *testing.T) {
|
||||
s := newFETStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte{})
|
||||
if u != nil || done {
|
||||
t.Error("empty data should return nil, false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETAnalyzer_Name(t *testing.T) {
|
||||
a := &FETAnalyzer{}
|
||||
if a.Name() != "fet" {
|
||||
t.Errorf("Name() = %q, want fet", a.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFETStream_Close(t *testing.T) {
|
||||
s := newFETStream(nil)
|
||||
if u := s.Close(false); u != nil {
|
||||
t.Error("Close should return nil")
|
||||
}
|
||||
}
|
||||
159
analyzer/tcp/ssh_test.go
Normal file
159
analyzer/tcp/ssh_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
func TestSSHAnalyzer_Name(t *testing.T) {
|
||||
a := &SSHAnalyzer{}
|
||||
if a.Name() != "ssh" {
|
||||
t.Errorf("Name() = %q, want ssh", a.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_Client(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3\r\n"))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if done {
|
||||
t.Error("should not be done (server not yet received)")
|
||||
}
|
||||
client, ok := u.M["client"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("client prop missing")
|
||||
}
|
||||
if client["protocol"] != "2.0" {
|
||||
t.Errorf("protocol = %v, want 2.0", client["protocol"])
|
||||
}
|
||||
if client["software"] != "OpenSSH_8.9p1" {
|
||||
t.Errorf("software = %v, want OpenSSH_8.9p1", client["software"])
|
||||
}
|
||||
if comments, ok := client["comments"]; ok {
|
||||
if comments != "Ubuntu-3" {
|
||||
t.Errorf("comments = %v, want Ubuntu-3", comments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_ClientWithComments(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_7.4 Ubuntu-3\r\n"))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if done {
|
||||
t.Error("should not be done")
|
||||
}
|
||||
client := u.M["client"].(analyzer.PropMap)
|
||||
if client["comments"] != "Ubuntu-3" {
|
||||
t.Errorf("comments = %v, want Ubuntu-3", client["comments"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_Both(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9\r\n"))
|
||||
u, done := s.Feed(true, false, false, 0, []byte("SSH-2.0-dropbear_2022.83\r\n"))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if !done {
|
||||
t.Error("should be done after both sides")
|
||||
}
|
||||
server, ok := u.M["server"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("server prop missing")
|
||||
}
|
||||
if server["software"] != "dropbear_2022.83" {
|
||||
t.Errorf("server software = %v", server["software"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_NotSSH(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("HTTP/1.1 200 OK\r\n"))
|
||||
if u != nil {
|
||||
t.Error("should return nil for non-SSH")
|
||||
}
|
||||
if !done {
|
||||
t.Error("should be cancelled (done) for non-SSH")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_InvalidLine(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-foo bar baz\r\n"))
|
||||
if u != nil {
|
||||
t.Error("should return nil for invalid line (>2 fields)")
|
||||
}
|
||||
if !done {
|
||||
t.Error("should be cancelled for invalid SSH line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_NoLineEnd(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH"))
|
||||
if u != nil {
|
||||
t.Error("should return nil when no EOL found yet")
|
||||
}
|
||||
if done {
|
||||
t.Error("should not be done, waiting for more data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_IncompleteThenComplete(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, _ := s.Feed(false, false, false, 0, []byte("SSH-2.0-"))
|
||||
if u != nil {
|
||||
t.Error("first partial feed should return nil")
|
||||
}
|
||||
u, done := s.Feed(false, false, false, 0, []byte("Dropbear\r\n"))
|
||||
if u == nil {
|
||||
t.Fatal("second feed should return update")
|
||||
}
|
||||
if done {
|
||||
t.Error("should not be done (only client)")
|
||||
}
|
||||
client := u.M["client"].(analyzer.PropMap)
|
||||
if client["software"] != "Dropbear" {
|
||||
t.Errorf("software = %v", client["software"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_Skip(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
_, done := s.Feed(false, false, false, 5, []byte("data"))
|
||||
if !done {
|
||||
t.Error("skip != 0 should return done=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Feed_Empty(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
u, done := s.Feed(false, false, false, 0, []byte{})
|
||||
if u != nil || done {
|
||||
t.Error("empty data should return nil, false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHStream_Close(t *testing.T) {
|
||||
s := newSSHStream(nil)
|
||||
s.clientBuf.Append([]byte("data"))
|
||||
s.serverBuf.Append([]byte("data"))
|
||||
s.clientMap = analyzer.PropMap{"key": "val"}
|
||||
u := s.Close(false)
|
||||
if u != nil {
|
||||
t.Error("Close should return nil")
|
||||
}
|
||||
if s.clientBuf.Len() != 0 || s.serverBuf.Len() != 0 {
|
||||
t.Error("Close should reset buffers")
|
||||
}
|
||||
if s.clientMap != nil || s.serverMap != nil {
|
||||
t.Error("Close should nil maps")
|
||||
}
|
||||
}
|
||||
187
analyzer/udp/wireguard_test.go
Normal file
187
analyzer/udp/wireguard_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
func makeWireGuardPacket(msgType byte, body []byte) []byte {
|
||||
buf := make([]byte, 1+3+len(body))
|
||||
buf[0] = msgType
|
||||
copy(buf[4:], body)
|
||||
return buf
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_HandshakeInitiation(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
body := make([]byte, wireguardSizeHandshakeInitiation-4)
|
||||
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
|
||||
|
||||
u, done := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
if done {
|
||||
t.Error("should not be done")
|
||||
}
|
||||
if u.Type != analyzer.PropUpdateReplace {
|
||||
t.Errorf("Type = %d, want PropUpdateReplace", u.Type)
|
||||
}
|
||||
initMap, ok := u.M["handshake_initiation"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("handshake_initiation missing")
|
||||
}
|
||||
if idx := initMap["sender_index"].(uint32); idx != 0x01020304 {
|
||||
t.Errorf("sender_index = %d, want %d", idx, 0x01020304)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_HandshakeResponse(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
body := make([]byte, wireguardSizeHandshakeResponse-4)
|
||||
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
|
||||
binary.LittleEndian.PutUint32(body[4:8], 0x05060708)
|
||||
|
||||
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeResponse, body))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
respMap, ok := u.M["handshake_response"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("handshake_response missing")
|
||||
}
|
||||
if idx := respMap["sender_index"].(uint32); idx != 0x01020304 {
|
||||
t.Errorf("sender_index = %d", idx)
|
||||
}
|
||||
if idx := respMap["receiver_index"].(uint32); idx != 0x05060708 {
|
||||
t.Errorf("receiver_index = %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_PacketData(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
body := make([]byte, wireguardMinSizePacketData-4)
|
||||
binary.LittleEndian.PutUint32(body[0:4], 0x0a0b0c0d)
|
||||
binary.LittleEndian.PutUint64(body[4:12], 42)
|
||||
|
||||
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeData, body))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
dataMap, ok := u.M["packet_data"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("packet_data missing")
|
||||
}
|
||||
if idx := dataMap["receiver_index"].(uint32); idx != 0x0a0b0c0d {
|
||||
t.Errorf("receiver_index = %d", idx)
|
||||
}
|
||||
if ctr := dataMap["counter"].(uint64); ctr != 42 {
|
||||
t.Errorf("counter = %d", ctr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_CookieReply(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
body := make([]byte, wireguardSizePacketCookieReply-4)
|
||||
binary.LittleEndian.PutUint32(body[0:4], 0x11111111)
|
||||
|
||||
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeCookieReply, body))
|
||||
if u == nil {
|
||||
t.Fatal("Feed returned nil update")
|
||||
}
|
||||
crMap, ok := u.M["packet_cookie_reply"].(analyzer.PropMap)
|
||||
if !ok {
|
||||
t.Fatal("packet_cookie_reply missing")
|
||||
}
|
||||
if idx := crMap["receiver_index"].(uint32); idx != 0x11111111 {
|
||||
t.Errorf("receiver_index = %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_InvalidLength(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
u, done := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||
if u == nil || !done {
|
||||
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||
}
|
||||
if u == nil || !done {
|
||||
u, _ = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||
}
|
||||
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||
if u != nil || !done {
|
||||
t.Error("should return done after 4 invalid packets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_TooShort(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
u, _ := s.Feed(false, []byte{1, 0, 0})
|
||||
if u != nil {
|
||||
t.Error("should return nil for too-short packet")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_NonZeroReserved(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
u, _ := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 1, 0, 0, 0, 0})
|
||||
if u != nil {
|
||||
t.Error("should return nil when reserved bytes are non-zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Feed_HandshakeInitiation_WrongSize(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
body := make([]byte, wireguardSizeHandshakeInitiation-4+1)
|
||||
binary.LittleEndian.PutUint32(body[0:4], 1)
|
||||
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
|
||||
if u != nil {
|
||||
t.Error("should return nil for wrong init size")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_Close(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
u := s.Close(false)
|
||||
if u != nil {
|
||||
t.Error("Close() should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_PutSenderIndex_MatchReceiverIndex(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
|
||||
s.putSenderIndex(false, 0xdeadbeef)
|
||||
if !s.matchReceiverIndex(true, 0xdeadbeef) {
|
||||
t.Error("should match reverse direction")
|
||||
}
|
||||
if s.matchReceiverIndex(false, 0xdeadbeef) {
|
||||
t.Error("should not match same direction")
|
||||
}
|
||||
if s.matchReceiverIndex(true, 0x12345678) {
|
||||
t.Error("should not match wrong index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardUDPStream_RememberedIndexRing(t *testing.T) {
|
||||
s := newWireGuardUDPStream(nil)
|
||||
|
||||
for i := uint32(0); i < wireguardRememberedIndexCount+2; i++ {
|
||||
s.putSenderIndex(false, i)
|
||||
}
|
||||
|
||||
if s.matchReceiverIndex(true, 0) {
|
||||
t.Error("index 0 should have been evicted from ring")
|
||||
}
|
||||
if !s.matchReceiverIndex(true, uint32(wireguardRememberedIndexCount+1)) {
|
||||
t.Error("latest index should still be in ring")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireGuardAnalyzer_Name(t *testing.T) {
|
||||
a := &WireGuardAnalyzer{}
|
||||
if a.Name() != "wireguard" {
|
||||
t.Errorf("Name() = %q, want wireguard", a.Name())
|
||||
}
|
||||
}
|
||||
321
analyzer/utils/bytebuffer_test.go
Normal file
321
analyzer/utils/bytebuffer_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestByteBuffer_Append(t *testing.T) {
|
||||
b := &ByteBuffer{}
|
||||
b.Append([]byte("hello"))
|
||||
b.Append([]byte(" "))
|
||||
b.Append([]byte("world"))
|
||||
if string(b.Buf) != "hello world" {
|
||||
t.Errorf("Append() result = %q, want %q", string(b.Buf), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_Len(t *testing.T) {
|
||||
b := &ByteBuffer{}
|
||||
if b.Len() != 0 {
|
||||
t.Errorf("Len() = %d, want 0", b.Len())
|
||||
}
|
||||
b.Append([]byte("abc"))
|
||||
if b.Len() != 3 {
|
||||
t.Errorf("Len() = %d, want 3", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_Index(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||
if i := b.Index([]byte("world")); i != 6 {
|
||||
t.Errorf("Index('world') = %d, want 6", i)
|
||||
}
|
||||
if i := b.Index([]byte("xyz")); i != -1 {
|
||||
t.Errorf("Index('xyz') = %d, want -1", i)
|
||||
}
|
||||
if i := b.Index([]byte{}); i != 0 {
|
||||
t.Errorf("Index('') = %d, want 0", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_Get(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("abcdef")}
|
||||
|
||||
data, ok := b.Get(3, true)
|
||||
if !ok {
|
||||
t.Fatal("Get(3, true) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("abc")) {
|
||||
t.Errorf("Get(3, true) data = %q, want %q", data, "abc")
|
||||
}
|
||||
if !bytes.Equal(b.Buf, []byte("def")) {
|
||||
t.Errorf("after consume, Buf = %q, want %q", b.Buf, "def")
|
||||
}
|
||||
|
||||
data, ok = b.Get(4, false)
|
||||
if ok {
|
||||
t.Fatal("Get(4, false) should return false (only 3 bytes left)")
|
||||
}
|
||||
|
||||
data, ok = b.Get(2, false)
|
||||
if !ok {
|
||||
t.Fatal("Get(2, false) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("de")) {
|
||||
t.Errorf("Get(2, false) data = %q, want %q", data, "de")
|
||||
}
|
||||
if !bytes.Equal(b.Buf, []byte("def")) {
|
||||
t.Errorf("after non-consume, Buf = %q, want %q (unchanged)", b.Buf, "def")
|
||||
}
|
||||
|
||||
data, ok = b.Get(3, true)
|
||||
if !ok {
|
||||
t.Fatal("Get(3, true) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("def")) {
|
||||
t.Errorf("Get(3, true) data = %q, want %q", data, "def")
|
||||
}
|
||||
if b.Len() != 0 {
|
||||
t.Errorf("after consume all, Len() = %d, want 0", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetString(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||
|
||||
s, ok := b.GetString(5, true)
|
||||
if !ok {
|
||||
t.Fatal("GetString(5, true) returned false")
|
||||
}
|
||||
if s != "hello" {
|
||||
t.Errorf("GetString() = %q, want %q", s, "hello")
|
||||
}
|
||||
if b.Len() != 6 {
|
||||
t.Errorf("after consume, Len() = %d, want 6", b.Len())
|
||||
}
|
||||
|
||||
s, ok = b.GetString(7, false)
|
||||
if ok {
|
||||
t.Fatal("GetString(7, false) should return false")
|
||||
}
|
||||
|
||||
s, ok = b.GetString(6, false)
|
||||
if !ok {
|
||||
t.Fatal("GetString(6, false) returned false")
|
||||
}
|
||||
if s != " world" {
|
||||
t.Errorf("GetString() = %q, want %q", s, " world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetByte(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("abc")}
|
||||
|
||||
bt, ok := b.GetByte(true)
|
||||
if !ok {
|
||||
t.Fatal("GetByte(true) returned false")
|
||||
}
|
||||
if bt != 'a' {
|
||||
t.Errorf("GetByte() = %c, want 'a'", bt)
|
||||
}
|
||||
if b.Len() != 2 {
|
||||
t.Errorf("after consume, Len() = %d, want 2", b.Len())
|
||||
}
|
||||
|
||||
bt, ok = b.GetByte(false)
|
||||
if !ok {
|
||||
t.Fatal("GetByte(false) returned false")
|
||||
}
|
||||
if bt != 'b' {
|
||||
t.Errorf("GetByte() = %c, want 'b'", bt)
|
||||
}
|
||||
if b.Len() != 2 {
|
||||
t.Errorf("after non-consume, Len() = %d, want 2 (unchanged)", b.Len())
|
||||
}
|
||||
|
||||
b.GetByte(true)
|
||||
b.GetByte(true)
|
||||
bt, ok = b.GetByte(true)
|
||||
if ok {
|
||||
t.Fatal("GetByte(true) on empty buffer should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetUint16(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04}}
|
||||
|
||||
v, ok := b.GetUint16(false, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUint16(bigEndian) returned false")
|
||||
}
|
||||
if v != 0x0102 {
|
||||
t.Errorf("GetUint16(bigEndian) = 0x%04x, want 0x0102", v)
|
||||
}
|
||||
if b.Len() != 2 {
|
||||
t.Errorf("after consume, Len() = %d, want 2", b.Len())
|
||||
}
|
||||
|
||||
v, ok = b.GetUint16(true, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUint16(littleEndian) returned false")
|
||||
}
|
||||
if v != 0x0403 {
|
||||
t.Errorf("GetUint16(littleEndian) = 0x%04x, want 0x0403", v)
|
||||
}
|
||||
|
||||
v, ok = b.GetUint16(false, false)
|
||||
if ok {
|
||||
t.Fatal("GetUint16 on empty buffer should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetUint32(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}
|
||||
|
||||
v, ok := b.GetUint32(false, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUint32(bigEndian) returned false")
|
||||
}
|
||||
if v != 0x01020304 {
|
||||
t.Errorf("GetUint32(bigEndian) = 0x%08x, want 0x01020304", v)
|
||||
}
|
||||
if b.Len() != 4 {
|
||||
t.Errorf("after consume, Len() = %d, want 4", b.Len())
|
||||
}
|
||||
|
||||
v, ok = b.GetUint32(true, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUint32(littleEndian) returned false")
|
||||
}
|
||||
if v != 0x08070605 {
|
||||
t.Errorf("GetUint32(littleEndian) = 0x%08x, want 0x08070605", v)
|
||||
}
|
||||
|
||||
v, ok = b.GetUint32(false, false)
|
||||
if ok {
|
||||
t.Fatal("GetUint32 on empty buffer should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetUntil(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("hello\r\nworld\r\n")}
|
||||
|
||||
data, ok := b.GetUntil([]byte("\r\n"), true, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUntil(sep, include) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("hello\r\n")) {
|
||||
t.Errorf("GetUntil(include) = %q, want %q", data, "hello\\r\\n")
|
||||
}
|
||||
|
||||
data, ok = b.GetUntil([]byte("\r\n"), false, false)
|
||||
if !ok {
|
||||
t.Fatal("GetUntil(sep, exclude, non-consume) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("world")) {
|
||||
t.Errorf("GetUntil(exclude) = %q, want %q", data, "world")
|
||||
}
|
||||
if b.Len() != 7 {
|
||||
t.Errorf("after non-consume, Len() = %d, want 7", b.Len())
|
||||
}
|
||||
|
||||
data, ok = b.GetUntil([]byte("\r\n"), true, true)
|
||||
if !ok {
|
||||
t.Fatal("GetUntil second (consume) returned false")
|
||||
}
|
||||
if !bytes.Equal(data, []byte("world\r\n")) {
|
||||
t.Errorf("GetUntil second = %q, want %q", data, "world\\r\\n")
|
||||
}
|
||||
|
||||
_, ok = b.GetUntil([]byte("xyz"), false, false)
|
||||
if ok {
|
||||
t.Fatal("GetUntil(not found) should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetSubBuffer(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||
|
||||
sub, ok := b.GetSubBuffer(5, true)
|
||||
if !ok {
|
||||
t.Fatal("GetSubBuffer() returned false")
|
||||
}
|
||||
if !bytes.Equal(sub.Buf, []byte("hello")) {
|
||||
t.Errorf("GetSubBuffer() = %q, want %q", sub.Buf, "hello")
|
||||
}
|
||||
if b.Len() != 6 {
|
||||
t.Errorf("after consume, Len() = %d, want 6", b.Len())
|
||||
}
|
||||
|
||||
_, ok = b.GetSubBuffer(7, false)
|
||||
if ok {
|
||||
t.Fatal("GetSubBuffer(7) should return false (only 6 bytes left)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_Skip(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("abcdef")}
|
||||
|
||||
ok := b.Skip(2)
|
||||
if !ok {
|
||||
t.Fatal("Skip(2) returned false")
|
||||
}
|
||||
if !bytes.Equal(b.Buf, []byte("cdef")) {
|
||||
t.Errorf("after Skip(2), Buf = %q, want %q", b.Buf, "cdef")
|
||||
}
|
||||
|
||||
ok = b.Skip(10)
|
||||
if ok {
|
||||
t.Fatal("Skip(10) should return false")
|
||||
}
|
||||
if !bytes.Equal(b.Buf, []byte("cdef")) {
|
||||
t.Errorf("after failed Skip, Buf = %q, want %q (unchanged)", b.Buf, "cdef")
|
||||
}
|
||||
|
||||
ok = b.Skip(4)
|
||||
if !ok {
|
||||
t.Fatal("Skip(4) returned false")
|
||||
}
|
||||
if b.Len() != 0 {
|
||||
t.Errorf("after Skip all, Len() = %d, want 0", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_Reset(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("data")}
|
||||
b.Reset()
|
||||
if b.Buf != nil {
|
||||
t.Errorf("after Reset, Buf = %v, want nil", b.Buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetZeroLength(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("abc")}
|
||||
|
||||
data, ok := b.Get(0, true)
|
||||
if !ok {
|
||||
t.Fatal("Get(0) returned false")
|
||||
}
|
||||
if len(data) != 0 {
|
||||
t.Errorf("Get(0) len = %d, want 0", len(data))
|
||||
}
|
||||
if b.Len() != 3 {
|
||||
t.Errorf("after Get(0, consume), Len() = %d, want 3 (0-length consume is no-op)", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteBuffer_GetConsumeDoesNotMutateReturnedSlice(t *testing.T) {
|
||||
b := &ByteBuffer{Buf: []byte("hello")}
|
||||
data, ok := b.Get(5, true)
|
||||
if !ok {
|
||||
t.Fatal("Get() returned false")
|
||||
}
|
||||
if !reflect.DeepEqual(data, []byte("hello")) {
|
||||
t.Errorf("Get() returned wrong data: %v", data)
|
||||
}
|
||||
if b.Len() != 0 {
|
||||
t.Errorf("after consume, Len() should be 0")
|
||||
}
|
||||
}
|
||||
185
analyzer/utils/lsm_test.go
Normal file
185
analyzer/utils/lsm_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLinearStateMachine_RunPause(t *testing.T) {
|
||||
callCount := 0
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction {
|
||||
callCount++
|
||||
return LSMActionPause
|
||||
},
|
||||
)
|
||||
cancelled, done := lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true")
|
||||
}
|
||||
if done {
|
||||
t.Error("unexpected done=true")
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("callCount = %d, want 1", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_RunNext(t *testing.T) {
|
||||
callOrder := []int{}
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction { callOrder = append(callOrder, 1); return LSMActionNext },
|
||||
func() LSMAction { callOrder = append(callOrder, 2); return LSMActionNext },
|
||||
func() LSMAction { callOrder = append(callOrder, 3); return LSMActionNext },
|
||||
)
|
||||
cancelled, done := lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true")
|
||||
}
|
||||
if !done {
|
||||
t.Error("unexpected done=false")
|
||||
}
|
||||
if len(callOrder) != 3 {
|
||||
t.Fatalf("callOrder len = %d, want 3", len(callOrder))
|
||||
}
|
||||
for i, v := range []int{1, 2, 3} {
|
||||
if callOrder[i] != v {
|
||||
t.Errorf("callOrder[%d] = %d, want %d", i, callOrder[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_RunReset(t *testing.T) {
|
||||
callCount := 0
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
return LSMActionReset
|
||||
}
|
||||
return LSMActionNext
|
||||
},
|
||||
func() LSMAction { callCount++; return LSMActionNext },
|
||||
)
|
||||
cancelled, done := lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true")
|
||||
}
|
||||
if !done {
|
||||
t.Error("unexpected done=false")
|
||||
}
|
||||
if callCount != 3 {
|
||||
t.Errorf("callCount = %d, want 3 (step0 reset, step0 next, step1 next)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_RunCancel(t *testing.T) {
|
||||
callCount := 0
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction { callCount++; return LSMActionNext },
|
||||
func() LSMAction { callCount++; return LSMActionCancel },
|
||||
func() LSMAction { callCount++; return LSMActionNext },
|
||||
)
|
||||
cancelled, done := lsm.Run()
|
||||
if !cancelled {
|
||||
t.Error("unexpected cancelled=false")
|
||||
}
|
||||
if !done {
|
||||
t.Error("unexpected done=false")
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Errorf("callCount = %d, want 2 (third step should not execute)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_RunMixed(t *testing.T) {
|
||||
pauseCount := 0
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction { return LSMActionNext },
|
||||
func() LSMAction {
|
||||
pauseCount++
|
||||
if pauseCount == 1 {
|
||||
return LSMActionPause
|
||||
}
|
||||
return LSMActionNext
|
||||
},
|
||||
func() LSMAction { return LSMActionNext },
|
||||
func() LSMAction { return LSMActionNext },
|
||||
)
|
||||
cancelled, done := lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true")
|
||||
}
|
||||
if done {
|
||||
t.Error("unexpected done=true on first run")
|
||||
}
|
||||
cancelled, done = lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true on second run")
|
||||
}
|
||||
if !done {
|
||||
t.Error("unexpected done=false on second run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_RunEmpty(t *testing.T) {
|
||||
lsm := NewLinearStateMachine()
|
||||
cancelled, done := lsm.Run()
|
||||
if cancelled {
|
||||
t.Error("unexpected cancelled=true")
|
||||
}
|
||||
if !done {
|
||||
t.Error("unexpected done=false for empty LSM")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_AppendSteps(t *testing.T) {
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction { return LSMActionNext },
|
||||
)
|
||||
lsm.Run()
|
||||
lsm.AppendSteps(
|
||||
func() LSMAction { return LSMActionNext },
|
||||
)
|
||||
_, done := lsm.Run()
|
||||
if !done {
|
||||
t.Error("unexpected done=false after AppendSteps")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinearStateMachine_Reset(t *testing.T) {
|
||||
callCount := 0
|
||||
lsm := NewLinearStateMachine(
|
||||
func() LSMAction { callCount++; return LSMActionCancel },
|
||||
)
|
||||
lsm.Run()
|
||||
if !lsm.cancelled {
|
||||
t.Error("expected cancelled=true after cancel")
|
||||
}
|
||||
lsm.Reset()
|
||||
if lsm.cancelled {
|
||||
t.Error("expected cancelled=false after Reset")
|
||||
}
|
||||
if lsm.index != 0 {
|
||||
t.Errorf("expected index=0 after Reset, got %d", lsm.index)
|
||||
}
|
||||
_, done := lsm.Run()
|
||||
if !done {
|
||||
t.Error("expected done=true, step executed again after Reset")
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Errorf("callCount = %d, want 2 (first run + reset run)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLSMActionConstants(t *testing.T) {
|
||||
if LSMActionPause != 0 {
|
||||
t.Errorf("LSMActionPause = %d, want 0", LSMActionPause)
|
||||
}
|
||||
if LSMActionNext != 1 {
|
||||
t.Errorf("LSMActionNext = %d, want 1", LSMActionNext)
|
||||
}
|
||||
if LSMActionReset != 2 {
|
||||
t.Errorf("LSMActionReset = %d, want 2", LSMActionReset)
|
||||
}
|
||||
if LSMActionCancel != 3 {
|
||||
t.Errorf("LSMActionCancel = %d, want 3", LSMActionCancel)
|
||||
}
|
||||
}
|
||||
29
analyzer/utils/string_test.go
Normal file
29
analyzer/utils/string_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestByteSlicesToStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input [][]byte
|
||||
want []string
|
||||
}{
|
||||
{"nil", nil, []string{}},
|
||||
{"empty", [][]byte{}, []string{}},
|
||||
{"single", [][]byte{[]byte("hello")}, []string{"hello"}},
|
||||
{"multiple", [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, []string{"foo", "bar", "baz"}},
|
||||
{"empty element", [][]byte{[]byte("a"), []byte{}, []byte("b")}, []string{"a", "", "b"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ByteSlicesToStrings(tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ByteSlicesToStrings() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
45
errors_test.go
Normal file
45
errors_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package mellaris
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfigError_Error(t *testing.T) {
|
||||
e := ConfigError{Field: "port", Err: errors.New("must be > 0")}
|
||||
want := "invalid config: port: must be > 0"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigError_Error_NilErr(t *testing.T) {
|
||||
e := ConfigError{Field: "host", Err: nil}
|
||||
want := "invalid config: host: %!s(<nil>)"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigError_Unwrap(t *testing.T) {
|
||||
wrapped := errors.New("inner error")
|
||||
e := ConfigError{Field: "field", Err: wrapped}
|
||||
if got := e.Unwrap(); got != wrapped {
|
||||
t.Errorf("Unwrap() = %v, want %v", got, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigError_Unwrap_Nil(t *testing.T) {
|
||||
e := ConfigError{Field: "field"}
|
||||
if got := e.Unwrap(); got != nil {
|
||||
t.Errorf("Unwrap() = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigError_ErrorsIs(t *testing.T) {
|
||||
wrapped := errors.New("inner")
|
||||
e := ConfigError{Field: "x", Err: wrapped}
|
||||
if !errors.Is(e, wrapped) {
|
||||
t.Error("errors.Is should find wrapped error")
|
||||
}
|
||||
}
|
||||
22
modifier/interface_test.go
Normal file
22
modifier/interface_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package modifier
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrInvalidPacket_Error(t *testing.T) {
|
||||
e := &ErrInvalidPacket{Err: errors.New("bad checksum")}
|
||||
want := "invalid packet: bad checksum"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrInvalidArgs_Error(t *testing.T) {
|
||||
e := &ErrInvalidArgs{Err: errors.New("missing 'a' arg")}
|
||||
want := "invalid args: missing 'a' arg"
|
||||
if got := e.Error(); got != want {
|
||||
t.Errorf("Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
205
modifier/udp/dns_test.go
Normal file
205
modifier/udp/dns_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/modifier"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte {
|
||||
t.Helper()
|
||||
dns := &layers.DNS{
|
||||
ID: id,
|
||||
QR: true,
|
||||
OpCode: layers.DNSOpCodeQuery,
|
||||
AA: false,
|
||||
RD: true,
|
||||
RA: true,
|
||||
ResponseCode: layers.DNSResponseCodeNoErr,
|
||||
QDCount: 1,
|
||||
Questions: []layers.DNSQuestion{
|
||||
{
|
||||
Name: []byte(name),
|
||||
Type: qtype,
|
||||
Class: layers.DNSClassIN,
|
||||
},
|
||||
},
|
||||
}
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}, dns)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serialize DNS response: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestDNSModifier_Name(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
if m.Name() != "dns" {
|
||||
t.Errorf("Name() = %q, want dns", m.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifier_New_NoArgs(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
inst, err := m.New(map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("New() unexpected error: %v", err)
|
||||
}
|
||||
if inst == nil {
|
||||
t.Fatal("New() returned nil instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifier_New_WithA(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
inst, err := m.New(map[string]interface{}{
|
||||
"a": "192.168.1.1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() unexpected error: %v", err)
|
||||
}
|
||||
dnsInst, ok := inst.(*dnsModifierInstance)
|
||||
if !ok {
|
||||
t.Fatalf("New() returned wrong type: %T", inst)
|
||||
}
|
||||
if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" {
|
||||
t.Errorf("A = %v, want 192.168.1.1", dnsInst.A)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifier_New_WithAAAA(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
inst, err := m.New(map[string]interface{}{
|
||||
"aaaa": "2001:db8::1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New() unexpected error: %v", err)
|
||||
}
|
||||
dnsInst, ok := inst.(*dnsModifierInstance)
|
||||
if !ok {
|
||||
t.Fatalf("New() returned wrong type: %T", inst)
|
||||
}
|
||||
if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" {
|
||||
t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifier_New_InvalidA(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
_, err := m.New(map[string]interface{}{
|
||||
"a": "not-an-ip",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("New() expected error for invalid A")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifier_New_InvalidAAAA(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
_, err := m.New(map[string]interface{}{
|
||||
"aaaa": "not-an-ip",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("New() expected error for invalid AAAA")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifierInstance_Process_AType(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
inst, err := m.New(map[string]interface{}{
|
||||
"a": "10.10.10.10",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com")
|
||||
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
||||
if err != nil {
|
||||
t.Fatalf("Process() error: %v", err)
|
||||
}
|
||||
|
||||
result := &layers.DNS{}
|
||||
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
||||
if err != nil {
|
||||
t.Fatalf("decode result: %v", err)
|
||||
}
|
||||
if len(result.Answers) != 1 {
|
||||
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
||||
}
|
||||
if result.Answers[0].Type != layers.DNSTypeA {
|
||||
t.Errorf("answer type = %d, want A", result.Answers[0].Type)
|
||||
}
|
||||
if result.Answers[0].IP.String() != "10.10.10.10" {
|
||||
t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifierInstance_Process_AAAAType(t *testing.T) {
|
||||
m := &DNSModifier{}
|
||||
inst, err := m.New(map[string]interface{}{
|
||||
"aaaa": "2001:db8::1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com")
|
||||
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
||||
if err != nil {
|
||||
t.Fatalf("Process() error: %v", err)
|
||||
}
|
||||
|
||||
result := &layers.DNS{}
|
||||
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
||||
if err != nil {
|
||||
t.Fatalf("decode result: %v", err)
|
||||
}
|
||||
if len(result.Answers) != 1 {
|
||||
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
||||
}
|
||||
if result.Answers[0].Type != layers.DNSTypeAAAA {
|
||||
t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) {
|
||||
inst := &dnsModifierInstance{}
|
||||
query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com")
|
||||
query[2] &^= 0x80 // clear QR bit
|
||||
|
||||
_, err := inst.Process(query)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-response")
|
||||
}
|
||||
errPkt, ok := err.(*modifier.ErrInvalidPacket)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ErrInvalidPacket, got %T", err)
|
||||
}
|
||||
_ = errPkt
|
||||
}
|
||||
|
||||
func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) {
|
||||
inst := &dnsModifierInstance{}
|
||||
dns := &layers.DNS{
|
||||
ID: 1,
|
||||
QR: true,
|
||||
RD: true,
|
||||
RA: true,
|
||||
ResponseCode: layers.DNSResponseCodeNoErr,
|
||||
}
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns)
|
||||
|
||||
_, err := inst.Process(buf.Bytes())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty questions")
|
||||
}
|
||||
}
|
||||
98
ruleset/builtins/cidr_test.go
Normal file
98
ruleset/builtins/cidr_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package builtins
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileCIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
wantStr string
|
||||
}{
|
||||
{"valid ipv4", "192.168.0.0/24", false, "192.168.0.0/24"},
|
||||
{"valid ipv6", "2001:db8::/32", false, "2001:db8::/32"},
|
||||
{"valid host ipv4", "10.0.0.1/32", false, "10.0.0.1/32"},
|
||||
{"valid host ipv6", "::1/128", false, "::1/128"},
|
||||
{"invalid no mask", "192.168.0.0", true, ""},
|
||||
{"invalid bad ip", "not-an-ip/24", true, ""},
|
||||
{"invalid empty", "", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CompileCIDR(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("CompileCIDR(%q) expected error, got nil", tt.input)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("CompileCIDR(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if got.String() != tt.wantStr {
|
||||
t.Errorf("CompileCIDR(%q) = %q, want %q", tt.input, got.String(), tt.wantStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR(t *testing.T) {
|
||||
cidr := mustCompileCIDR(t, "192.168.0.0/24")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
want bool
|
||||
}{
|
||||
{"inside", "192.168.0.1", true},
|
||||
{"boundary low", "192.168.0.0", true},
|
||||
{"boundary high", "192.168.0.255", true},
|
||||
{"outside", "192.168.1.1", false},
|
||||
{"different network", "10.0.0.1", false},
|
||||
{"invalid ip", "not-an-ip", false},
|
||||
{"empty", "", false},
|
||||
{"ipv6 in ipv4", "::1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MatchCIDR(tt.ip, cidr)
|
||||
if got != tt.want {
|
||||
t.Errorf("MatchCIDR(%q, %q) = %v, want %v", tt.ip, cidr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR_IPv6(t *testing.T) {
|
||||
cidr := mustCompileCIDR(t, "2001:db8::/32")
|
||||
|
||||
inside := "2001:db8::1"
|
||||
if !MatchCIDR(inside, cidr) {
|
||||
t.Errorf("MatchCIDR(%q) should be true", inside)
|
||||
}
|
||||
|
||||
outside := "2001:db9::1"
|
||||
if MatchCIDR(outside, cidr) {
|
||||
t.Errorf("MatchCIDR(%q) should be false", outside)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchCIDR_NullResult(t *testing.T) {
|
||||
if MatchCIDR("10.0.0.1", &net.IPNet{}) {
|
||||
t.Error("MatchCIDR with empty IPNet should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func mustCompileCIDR(t *testing.T, cidr string) *net.IPNet {
|
||||
t.Helper()
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CIDR %q: %v", cidr, err)
|
||||
}
|
||||
return ipNet
|
||||
}
|
||||
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||
)
|
||||
|
||||
type fakeGeoLoader struct {
|
||||
geoip map[string]*v2geo.GeoIP
|
||||
geosite map[string]*v2geo.GeoSite
|
||||
}
|
||||
|
||||
func (l *fakeGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
||||
return l.geoip, nil
|
||||
}
|
||||
|
||||
func (l *fakeGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
||||
return l.geosite, nil
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_Cached(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geoip: map[string]*v2geo.GeoIP{
|
||||
"us": {
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: ipv4(8, 8, 8, 0), Prefix: 24},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if !g.MatchGeoIp("8.8.8.8", "US") {
|
||||
t.Error("MatchGeoIp should match 8.8.8.8 in US range")
|
||||
}
|
||||
if g.MatchGeoIp("9.9.9.9", "US") {
|
||||
t.Error("MatchGeoIp should not match 9.9.9.9 in US range")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_EmptyCondition(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoIp("1.2.3.4", "") {
|
||||
t.Error("MatchGeoIp with empty condition should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_InvalidIP(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoIp("not-an-ip", "us") {
|
||||
t.Error("MatchGeoIp with invalid IP should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoIp_MissingCountry(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geoip: map[string]*v2geo.GeoIP{},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if g.MatchGeoIp("8.8.8.8", "us") {
|
||||
t.Error("MatchGeoIp for missing country should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geosite: map[string]*v2geo.GeoSite{
|
||||
"openai": {
|
||||
Domain: []*v2geo.Domain{
|
||||
{Type: v2geo.Domain_Plain, Value: "openai"},
|
||||
{Type: v2geo.Domain_Full, Value: "chatgpt.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if !g.MatchGeoSite("api.openai.com", "openai") {
|
||||
t.Error("MatchGeoSite should match via plain domain")
|
||||
}
|
||||
if !g.MatchGeoSite("chatgpt.com", "openai") {
|
||||
t.Error("MatchGeoSite should match via full domain")
|
||||
}
|
||||
if g.MatchGeoSite("google.com", "openai") {
|
||||
t.Error("MatchGeoSite should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite_EmptyCondition(t *testing.T) {
|
||||
g := NewGeoMatcher("", "")
|
||||
if g.MatchGeoSite("test.com", "") {
|
||||
t.Error("MatchGeoSite with empty condition should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
|
||||
loader := &fakeGeoLoader{
|
||||
geosite: map[string]*v2geo.GeoSite{},
|
||||
}
|
||||
g := NewGeoMatcher("", "")
|
||||
g.geoLoader = loader
|
||||
|
||||
if g.MatchGeoSite("test.com", "nonexistent") {
|
||||
t.Error("MatchGeoSite for missing site should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func ipv4(a, b, c, d byte) []byte {
|
||||
return []byte{a, b, c, d}
|
||||
}
|
||||
BIN
ruleset/builtins/geo/geoip.dat
Normal file
BIN
ruleset/builtins/geo/geoip.dat
Normal file
Binary file not shown.
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||
)
|
||||
|
||||
func TestParseGeoSiteName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantBase string
|
||||
wantAttrs []string
|
||||
}{
|
||||
{"google", "google", nil},
|
||||
{"google@ads", "google", []string{"ads"}},
|
||||
{"google@ads@news", "google", []string{"ads", "news"}},
|
||||
{" google ", "google", nil},
|
||||
{" google @ ads ", "google", []string{"ads"}},
|
||||
{"openai@ ads @ news ", "openai", []string{"ads", "news"}},
|
||||
{"@onlyattrs", "", []string{"onlyattrs"}},
|
||||
{"", "", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
base, attrs := parseGeoSiteName(tt.input)
|
||||
if base != tt.wantBase {
|
||||
t.Errorf("parseGeoSiteName(%q) base = %q, want %q", tt.input, base, tt.wantBase)
|
||||
}
|
||||
if len(attrs) != len(tt.wantAttrs) {
|
||||
t.Fatalf("parseGeoSiteName(%q) attrs len = %d, want %d", tt.input, len(attrs), len(tt.wantAttrs))
|
||||
}
|
||||
for i, attr := range attrs {
|
||||
if attr != tt.wantAttrs[i] {
|
||||
t.Errorf("parseGeoSiteName(%q) attrs[%d] = %q, want %q", tt.input, i, attr, tt.wantAttrs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostInfo_String(t *testing.T) {
|
||||
h := HostInfo{
|
||||
Name: "example.com",
|
||||
IPv4: net.ParseIP("1.2.3.4"),
|
||||
IPv6: net.ParseIP("::1"),
|
||||
}
|
||||
want := "example.com|1.2.3.4|::1"
|
||||
if got := h.String(); got != want {
|
||||
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostInfo_String_Partial(t *testing.T) {
|
||||
h := HostInfo{
|
||||
Name: "test.com",
|
||||
IPv4: net.ParseIP("10.0.0.1"),
|
||||
}
|
||||
want := "test.com|10.0.0.1|<nil>"
|
||||
if got := h.String(); got != want {
|
||||
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match(t *testing.T) {
|
||||
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
_, n4_2, _ := net.ParseCIDR("192.168.0.0/16")
|
||||
m := &geoipMatcher{
|
||||
N4: []*net.IPNet{n4, n4_2},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host HostInfo
|
||||
want bool
|
||||
}{
|
||||
{"ipv4 match", HostInfo{IPv4: net.ParseIP("10.1.2.3")}, true},
|
||||
{"ipv4 no match", HostInfo{IPv4: net.ParseIP("172.16.0.1")}, false},
|
||||
{"ipv4 match second net", HostInfo{IPv4: net.ParseIP("192.168.1.1")}, true},
|
||||
{"no ip", HostInfo{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := m.Match(tt.host); got != tt.want {
|
||||
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match_Inverse(t *testing.T) {
|
||||
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
m := &geoipMatcher{
|
||||
N4: []*net.IPNet{n4},
|
||||
Inverse: true,
|
||||
}
|
||||
|
||||
if m.Match(HostInfo{IPv4: net.ParseIP("10.1.2.3")}) {
|
||||
t.Error("Inverse: inside range should return false")
|
||||
}
|
||||
if !m.Match(HostInfo{IPv4: net.ParseIP("172.16.0.1")}) {
|
||||
t.Error("Inverse: outside range should return true")
|
||||
}
|
||||
if !m.Match(HostInfo{}) {
|
||||
t.Error("Inverse: no IP should return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoipMatcher_Match_IPv6(t *testing.T) {
|
||||
_, n6, _ := net.ParseCIDR("2001:db8::/32")
|
||||
m := &geoipMatcher{
|
||||
N6: []*net.IPNet{n6},
|
||||
}
|
||||
|
||||
if !m.Match(HostInfo{IPv6: net.ParseIP("2001:db8::1")}) {
|
||||
t.Error("IPv6 match failed")
|
||||
}
|
||||
if m.Match(HostInfo{IPv6: net.ParseIP("2001:db9::1")}) {
|
||||
t.Error("IPv6 should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Plain(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "openai",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "api.openai.com"}) {
|
||||
t.Error("plain domain should match via substring")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||
t.Error("plain domain should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Full(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainFull,
|
||||
Value: "example.com",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||
t.Error("full domain should match exact")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||
t.Error("full domain should not match subdomain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Root(t *testing.T) {
|
||||
m := &geositeMatcher{}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainRoot,
|
||||
Value: "example.com",
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||
t.Error("root domain should match exact")
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||
t.Error("root domain should match subdomain")
|
||||
}
|
||||
if m.matchDomain(d, HostInfo{Name: "www.example.com.au"}) {
|
||||
t.Error("root domain should not match unrelated suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_matchDomain_Attrs(t *testing.T) {
|
||||
m := &geositeMatcher{Attrs: []string{"ads"}}
|
||||
d := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{"ads": true},
|
||||
}
|
||||
if !m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should match when domain has required attr")
|
||||
}
|
||||
|
||||
dNoAttrs := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{},
|
||||
}
|
||||
if m.matchDomain(dNoAttrs, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should not match when domain lacks required attr")
|
||||
}
|
||||
|
||||
dOtherAttrs := geositeDomain{
|
||||
Type: geositeDomainPlain,
|
||||
Value: "google",
|
||||
Attrs: map[string]bool{"news": true},
|
||||
}
|
||||
if m.matchDomain(dOtherAttrs, HostInfo{Name: "google.com"}) {
|
||||
t.Error("should not match when domain has wrong attr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeositeMatcher_Match(t *testing.T) {
|
||||
m := &geositeMatcher{
|
||||
Domains: []geositeDomain{
|
||||
{Type: geositeDomainFull, Value: "exact.com"},
|
||||
{Type: geositeDomainPlain, Value: "partial"},
|
||||
},
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "exact.com"}) {
|
||||
t.Error("should match full domain")
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "www.partial.net"}) {
|
||||
t.Error("should match partial domain")
|
||||
}
|
||||
if m.Match(HostInfo{Name: "other.net"}) {
|
||||
t.Error("should not match unrelated host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainAttributeToMap(t *testing.T) {
|
||||
attrs := []*v2geo.Domain_Attribute{
|
||||
{Key: "ads"},
|
||||
{Key: "news"},
|
||||
}
|
||||
got := domainAttributeToMap(attrs)
|
||||
if len(got) != 2 || !got["ads"] || !got["news"] {
|
||||
t.Errorf("domainAttributeToMap = %v, want {ads:true, news:true}", got)
|
||||
}
|
||||
|
||||
got2 := domainAttributeToMap(nil)
|
||||
if len(got2) != 0 {
|
||||
t.Errorf("domainAttributeToMap(nil) = %v, want empty map", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: net.IPv4(10, 0, 0, 0).To4(), Prefix: 8},
|
||||
{Ip: net.IPv4(192, 168, 0, 0).To4(), Prefix: 16},
|
||||
},
|
||||
InverseMatch: false,
|
||||
}
|
||||
m, err := newGeoIPMatcher(list)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||
}
|
||||
if len(m.N4) != 2 {
|
||||
t.Errorf("expected 2 IPv4 nets, got %d", len(m.N4))
|
||||
}
|
||||
if m.Inverse {
|
||||
t.Error("Inverse should be false")
|
||||
}
|
||||
// Verify sorted order: 10.0.0.0/8 < 192.168.0.0/16
|
||||
if m.N4[0].IP.String() != "10.0.0.0" {
|
||||
t.Errorf("N4[0] = %s, want 10.0.0.0", m.N4[0].IP)
|
||||
}
|
||||
if m.N4[1].IP.String() != "192.168.0.0" {
|
||||
t.Errorf("N4[1] = %s, want 192.168.0.0", m.N4[1].IP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher_IPv6(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: net.ParseIP("2001:db8::"), Prefix: 32},
|
||||
},
|
||||
}
|
||||
m, err := newGeoIPMatcher(list)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||
}
|
||||
if len(m.N6) != 1 {
|
||||
t.Errorf("expected 1 IPv6 net, got %d", len(m.N6))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeoIPMatcher_InvalidIPLength(t *testing.T) {
|
||||
list := &v2geo.GeoIP{
|
||||
Cidr: []*v2geo.CIDR{
|
||||
{Ip: []byte{1, 2, 3}, Prefix: 24},
|
||||
},
|
||||
}
|
||||
_, err := newGeoIPMatcher(list)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid IP length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeositeMatcher(t *testing.T) {
|
||||
list := &v2geo.GeoSite{
|
||||
Domain: []*v2geo.Domain{
|
||||
{Type: v2geo.Domain_Plain, Value: "google"},
|
||||
{Type: v2geo.Domain_Full, Value: "exact.com"},
|
||||
},
|
||||
}
|
||||
m, err := newGeositeMatcher(list, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||
}
|
||||
if len(m.Domains) != 2 {
|
||||
t.Errorf("expected 2 domains, got %d", len(m.Domains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGeositeMatcher_WithAttrs(t *testing.T) {
|
||||
list := &v2geo.GeoSite{
|
||||
Domain: []*v2geo.Domain{
|
||||
{
|
||||
Type: v2geo.Domain_RootDomain,
|
||||
Value: "google.com",
|
||||
Attribute: []*v2geo.Domain_Attribute{
|
||||
{Key: "ads"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m, err := newGeositeMatcher(list, []string{"ads"})
|
||||
if err != nil {
|
||||
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||
}
|
||||
if !m.Match(HostInfo{Name: "www.google.com"}) {
|
||||
t.Error("should match with root domain and attr")
|
||||
}
|
||||
}
|
||||
115
ruleset/expr.go
115
ruleset/expr.go
@@ -31,6 +31,9 @@ type ExprRule struct {
|
||||
Log bool `yaml:"log"`
|
||||
Modifier ModifierEntry `yaml:"modifier"`
|
||||
Expr string `yaml:"expr"`
|
||||
StartTime string `yaml:"start_time"`
|
||||
StopTime string `yaml:"stop_time"`
|
||||
Weekdays []string `yaml:"weekdays"`
|
||||
}
|
||||
|
||||
type ModifierEntry struct {
|
||||
@@ -56,6 +59,10 @@ type compiledExprRule struct {
|
||||
ModInstance modifier.Instance
|
||||
Program *vm.Program
|
||||
GeoSiteConditions []string
|
||||
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||
Weekdays []time.Weekday
|
||||
WeekdaysNegated bool
|
||||
}
|
||||
|
||||
var _ Ruleset = (*exprRuleset)(nil)
|
||||
@@ -73,10 +80,13 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
||||
|
||||
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||
env := streamInfoToExprEnv(info)
|
||||
now := time.Now()
|
||||
for _, rule := range r.Rules {
|
||||
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||
continue
|
||||
}
|
||||
v, err := vm.Run(rule.Program, env)
|
||||
if err != nil {
|
||||
// Log the error and continue to the next rule.
|
||||
r.Logger.MatchError(info, rule.Name, err)
|
||||
continue
|
||||
}
|
||||
@@ -163,12 +173,34 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
depAnMap[name] = a
|
||||
}
|
||||
}
|
||||
startSecs := -1
|
||||
if rule.StartTime != "" {
|
||||
startSecs, err = parseTimeOfDay(rule.StartTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err)
|
||||
}
|
||||
}
|
||||
stopSecs := -1
|
||||
if rule.StopTime != "" {
|
||||
stopSecs, err = parseTimeOfDay(rule.StopTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err)
|
||||
}
|
||||
}
|
||||
weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
||||
}
|
||||
cr := compiledExprRule{
|
||||
Name: rule.Name,
|
||||
Action: action,
|
||||
Log: rule.Log,
|
||||
Program: program,
|
||||
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
||||
StartTimeSecs: startSecs,
|
||||
StopTimeSecs: stopSecs,
|
||||
Weekdays: weekdays,
|
||||
WeekdaysNegated: weekdaysNegated,
|
||||
}
|
||||
if action != nil && *action == ActionModify {
|
||||
mod, ok := fullModMap[rule.Modifier.Name]
|
||||
@@ -391,6 +423,87 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
|
||||
}, geoMatcher
|
||||
}
|
||||
|
||||
func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool {
|
||||
if startSecs >= 0 || stopSecs >= 0 {
|
||||
currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second()
|
||||
if startSecs >= 0 && stopSecs >= 0 {
|
||||
if startSecs <= stopSecs {
|
||||
if currentSecs < startSecs || currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if currentSecs < startSecs && currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else if startSecs >= 0 {
|
||||
if currentSecs < startSecs {
|
||||
return false
|
||||
}
|
||||
} else if currentSecs > stopSecs {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(weekdays) > 0 {
|
||||
current := now.Weekday()
|
||||
found := false
|
||||
for _, d := range weekdays {
|
||||
if current == d {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if negated == found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func parseTimeOfDay(s string) (int, error) {
|
||||
t, err := time.Parse("15:04:05", s)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s)
|
||||
}
|
||||
return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil
|
||||
}
|
||||
|
||||
func parseWeekdays(days []string) ([]time.Weekday, bool, error) {
|
||||
if len(days) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
negated := false
|
||||
parsed := make([]time.Weekday, 0, len(days))
|
||||
for i, d := range days {
|
||||
d = strings.TrimSpace(d)
|
||||
if i == 0 && strings.HasPrefix(d, "!") {
|
||||
negated = true
|
||||
d = strings.TrimSpace(strings.TrimPrefix(d, "!"))
|
||||
}
|
||||
var wd time.Weekday
|
||||
switch strings.ToLower(d) {
|
||||
case "sun", "sunday":
|
||||
wd = time.Sunday
|
||||
case "mon", "monday":
|
||||
wd = time.Monday
|
||||
case "tue", "tues", "tuesday":
|
||||
wd = time.Tuesday
|
||||
case "wed", "wednesday":
|
||||
wd = time.Wednesday
|
||||
case "thu", "thur", "thurs", "thursday":
|
||||
wd = time.Thursday
|
||||
case "fri", "friday":
|
||||
wd = time.Friday
|
||||
case "sat", "saturday":
|
||||
wd = time.Saturday
|
||||
default:
|
||||
return nil, false, fmt.Errorf("invalid weekday %q", d)
|
||||
}
|
||||
parsed = append(parsed, wd)
|
||||
}
|
||||
return parsed, negated, nil
|
||||
}
|
||||
|
||||
const rulesetLogMetaKey = "_ruleset"
|
||||
|
||||
func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo {
|
||||
|
||||
121
ruleset/interface_test.go
Normal file
121
ruleset/interface_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
)
|
||||
|
||||
func TestAction_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
action Action
|
||||
want string
|
||||
}{
|
||||
{ActionMaybe, "maybe"},
|
||||
{ActionAllow, "allow"},
|
||||
{ActionBlock, "block"},
|
||||
{ActionDrop, "drop"},
|
||||
{ActionModify, "modify"},
|
||||
{Action(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
if got := tt.action.String(); got != tt.want {
|
||||
t.Errorf("Action.String() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocol_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
protocol Protocol
|
||||
want string
|
||||
}{
|
||||
{ProtocolTCP, "tcp"},
|
||||
{ProtocolUDP, "udp"},
|
||||
{Protocol(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
if got := tt.protocol.String(); got != tt.want {
|
||||
t.Errorf("Protocol.String() = %q, want %q", got, tt.protocol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocol_Constants(t *testing.T) {
|
||||
if ProtocolTCP != 0 {
|
||||
t.Errorf("ProtocolTCP = %d, want 0", ProtocolTCP)
|
||||
}
|
||||
if ProtocolUDP != 1 {
|
||||
t.Errorf("ProtocolUDP = %d, want 1", ProtocolUDP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAction_Constants(t *testing.T) {
|
||||
if ActionMaybe != 0 {
|
||||
t.Errorf("ActionMaybe = %d, want 0", ActionMaybe)
|
||||
}
|
||||
if ActionAllow != 1 {
|
||||
t.Errorf("ActionAllow = %d, want 1", ActionAllow)
|
||||
}
|
||||
if ActionBlock != 2 {
|
||||
t.Errorf("ActionBlock = %d, want 2", ActionBlock)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_SrcString(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
SrcIP: net.ParseIP("192.168.1.1"),
|
||||
SrcPort: 8080,
|
||||
}
|
||||
want := "192.168.1.1:8080"
|
||||
if got := info.SrcString(); got != want {
|
||||
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_DstString(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
DstIP: net.ParseIP("10.0.0.1"),
|
||||
DstPort: 443,
|
||||
}
|
||||
want := "10.0.0.1:443"
|
||||
if got := info.DstString(); got != want {
|
||||
t.Errorf("DstString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_SrcString_IPv6(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
SrcIP: net.ParseIP("::1"),
|
||||
SrcPort: 53,
|
||||
}
|
||||
want := "[::1]:53"
|
||||
if got := info.SrcString(); got != want {
|
||||
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchResult_ZeroValue(t *testing.T) {
|
||||
var mr MatchResult
|
||||
if mr.Action != ActionMaybe {
|
||||
t.Errorf("zero MatchResult.Action = %v, want ActionMaybe (0)", mr.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamInfo_PropsInitialization(t *testing.T) {
|
||||
info := StreamInfo{
|
||||
Props: analyzer.CombinedPropMap{
|
||||
"tls": analyzer.PropMap{"sni": "example.com"},
|
||||
},
|
||||
}
|
||||
if info.Props.Get("tls", "sni") != "example.com" {
|
||||
t.Error("StreamInfo.Props not properly initialized")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user