test: improve coverage across package

This commit is contained in:
2026-05-01 14:09:10 +05:30
parent e1c68ec7d0
commit e3f1f5046a
18 changed files with 2652 additions and 6 deletions

View File

@@ -41,6 +41,29 @@ func main() {
}
```
## Testing
Run all tests:
```bash
go test ./...
```
Run tests with coverage:
```bash
go test -cover ./... # per-package summary
go test -coverprofile=coverage.out ./... # full profile
go tool cover -html=coverage.out -o coverage.html # HTML report
```
Run tests for a specific package:
```bash
go test -v ./analyzer/tcp/
go test -v -run TestSSH ./analyzer/tcp/
```
Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW
## LICENSE

124
analyzer/interface_test.go Normal file
View File

@@ -0,0 +1,124 @@
package analyzer
import (
"reflect"
"testing"
)
func TestPropMap_Get(t *testing.T) {
m := PropMap{
"a": "value-a",
"nested": PropMap{
"b": "value-b",
"deep": PropMap{
"c": "value-c",
},
},
}
tests := []struct {
name string
key string
want interface{}
}{
{"simple key", "a", "value-a"},
{"nested key", "nested.b", "value-b"},
{"deeply nested", "nested.deep.c", "value-c"},
{"non-existent top", "x", nil},
{"non-existent nested", "nested.x", nil},
{"empty key", "", nil},
{"partial path", "nested", PropMap{"b": "value-b", "deep": PropMap{"c": "value-c"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := m.Get(tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PropMap.Get(%q) = %v, want %v", tt.key, got, tt.want)
}
})
}
}
func TestPropMap_Get_NilMap(t *testing.T) {
var m PropMap
if got := m.Get("any"); got != nil {
t.Errorf("nil PropMap.Get() = %v, want nil", got)
}
}
func TestPropMap_Get_EmptyMap(t *testing.T) {
m := PropMap{}
if got := m.Get("any"); got != nil {
t.Errorf("empty PropMap.Get() = %v, want nil", got)
}
}
func TestPropMap_Get_IntermediateNonMap(t *testing.T) {
m := PropMap{
"a": "hello",
}
if got := m.Get("a.b.c"); got != nil {
t.Errorf("PropMap.Get() through non-map = %v, want nil", got)
}
}
func TestCombinedPropMap_Get(t *testing.T) {
cm := CombinedPropMap{
"tls": PropMap{
"req": PropMap{
"sni": "example.com",
},
},
"http": PropMap{
"resp": PropMap{
"status": 200,
},
},
}
tests := []struct {
name string
analyzer string
key string
want interface{}
}{
{"tls sni", "tls", "req.sni", "example.com"},
{"http status", "http", "resp.status", 200},
{"unknown analyzer", "dns", "any", nil},
{"valid analyzer bad key", "tls", "req.nonexistent", nil},
{"empty analyzer", "", "key", nil},
{"empty key", "tls", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cm.Get(tt.analyzer, tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CombinedPropMap.Get(%q, %q) = %v, want %v", tt.analyzer, tt.key, got, tt.want)
}
})
}
}
func TestCombinedPropMap_Get_Nil(t *testing.T) {
var cm CombinedPropMap
if got := cm.Get("any", "key"); got != nil {
t.Errorf("nil CombinedPropMap.Get() = %v, want nil", got)
}
}
func TestPropUpdateType_Values(t *testing.T) {
if PropUpdateNone != 0 {
t.Errorf("PropUpdateNone = %d, want 0", PropUpdateNone)
}
if PropUpdateMerge != 1 {
t.Errorf("PropUpdateMerge = %d, want 1", PropUpdateMerge)
}
if PropUpdateReplace != 2 {
t.Errorf("PropUpdateReplace = %d, want 2", PropUpdateReplace)
}
if PropUpdateDelete != 3 {
t.Errorf("PropUpdateDelete = %d, want 3", PropUpdateDelete)
}
}

View File

@@ -0,0 +1,319 @@
package internal
import (
"reflect"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
func buildClientHelloMsg(t *testing.T) *utils.ByteBuffer {
t.Helper()
// Bytes taken from the standard TLS test vector, starting after the record
// header (5 bytes) and handshake header (4 bytes).
body := []byte{
0x03, 0x03, // version TLS 1.2
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length = 0
0x00, 0x20, // cipher suites length = 32 bytes = 16 suites
0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30,
0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09,
0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d,
0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, // ciphers
0x01, // compression methods length = 1
0x00, // compression method = null
0x00, 0x58, // extensions length = 88
// extension: server_name
0x00, 0x00, // type = server_name
0x00, 0x18, // length = 24
0x00, 0x16, // server name list length = 22
0x00, // name type = hostname
0x00, 0x13, // name length = 19
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'u', 'l', 'f', 'h', 'e', 'i', 'm', '.', 'n', 'e', 't',
// extension: status_request
0x00, 0x05, // type = status_request
0x00, 0x05, // length = 5
0x01, 0x00, 0x00, 0x00, 0x00,
// extension: supported_groups
0x00, 0x0a, // type = supported_groups
0x00, 0x0a, // length = 10
0x00, 0x08, // list length = 8
0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19,
// extension: ec_point_formats
0x00, 0x0b, // type = ec_point_formats
0x00, 0x02, // length = 2
0x01, 0x00,
// extension: signature_algorithms
0x00, 0x0d, // type = signature_algorithms
0x00, 0x12, // length = 18
0x00, 0x10, // list length = 16
0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03,
0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03,
// extension: renegotiation_info
0xff, 0x01, // type = renegotiation_info
0x00, 0x01, // length = 1
0x00,
// extension: extended_master_secret (empty)
0x00, 0x17, // type = extended_master_secret
0x00, 0x00, // length = 0
// extension: session_ticket (empty)
0x00, 0x23, // type = session_ticket
0x00, 0x00, // length = 0
}
return &utils.ByteBuffer{Buf: body}
}
func TestParseTLSClientHelloMsgData(t *testing.T) {
chBuf := buildClientHelloMsg(t)
m := ParseTLSClientHelloMsgData(chBuf)
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
wantVersion := uint16(0x0303)
if v, ok := m["version"].(uint16); !ok || v != wantVersion {
t.Errorf("version = %v, want %v", m["version"], wantVersion)
}
wantCiphers := []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10}
if c, ok := m["ciphers"].([]uint16); ok {
if !reflect.DeepEqual(c, wantCiphers) {
t.Errorf("ciphers = %v, want %v", c, wantCiphers)
}
} else {
t.Errorf("ciphers missing or wrong type: %T", m["ciphers"])
}
if sni, ok := m["sni"].(string); !ok || sni != "example.ulfheim.net" {
t.Errorf("sni = %q, want %q", m["sni"], "example.ulfheim.net")
}
if _, ok := m["compression"]; !ok {
t.Error("compression key missing")
}
if _, ok := m["session"]; !ok {
t.Error("session key missing")
}
if _, ok := m["random"]; !ok {
t.Error("random key missing")
}
}
func TestParseTLSServerHelloMsgData(t *testing.T) {
// ServerHello message body (after record header + handshake header)
body := []byte{
0x03, 0x03, // version TLS 1.2
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, // random
0x00, // session ID length = 0
0xc0, 0x13, // cipher suite
0x00, // compression method
0x00, 0x05, // extensions length = 5
// extension: renegotiation_info
0xff, 0x01, // type
0x00, 0x01, // length = 1
0x00,
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil")
}
wantCipher := uint16(0xc013)
if c, ok := m["cipher"].(uint16); !ok || c != wantCipher {
t.Errorf("cipher = %v, want %v", m["cipher"], wantCipher)
}
wantCompression := uint8(0)
if c, ok := m["compression"].(uint8); !ok || c != wantCompression {
t.Errorf("compression = %v, want %v", m["compression"], wantCompression)
}
}
func TestParseTLSServerHelloMsgData_NoExtensions(t *testing.T) {
// ServerHello message without extensions
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0xff, // cipher suite
0x00, // compression method
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil for no-extensions case")
}
if _, ok := m["cipher"]; !ok {
t.Error("cipher key missing")
}
}
func TestParseTLSClientHelloMsgData_Truncated(t *testing.T) {
tests := []struct {
name string
buf []byte
}{
{"too short for session id", []byte{0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}},
{"odd cipher suites length", []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x03, // odd cipher suites length
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: tt.buf})
if m != nil {
t.Error("expected nil for truncated input")
}
})
}
}
func TestParseTLSClientHelloMsgData_ECH(t *testing.T) {
// ClientHello with ECH extension
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // TLS_AES_128_GCM_SHA256
0x01, // compression methods length = 1
0x00, // null compression
0x00, 0x07, // extensions length = 7
// ECH extension
0xfe, 0x0d, // type = encrypted_client_hello
0x00, 0x03, // length = 3
0x01, 0x02, 0x03, // some ECH data
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
ech, ok := m["ech"].(bool)
if !ok || !ech {
t.Errorf("ech = %v, want true", m["ech"])
}
}
func TestParseTLSClientHelloMsgData_ALPN(t *testing.T) {
// ClientHello with ALPN extension
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // cipher suite
0x01, // compression methods length = 1
0x00, // compression method
0x00, 0x12, // extensions length = 18
// ALPN extension
0x00, 0x10, // type = ALPN
0x00, 0x0e, // length = 14
0x00, 0x0c, // list length = 12
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
0x02, 'h', '2',
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
alpn, ok := m["alpn"].([]string)
if !ok {
t.Fatalf("alpn missing or wrong type: %T", m["alpn"])
}
if len(alpn) != 2 || alpn[0] != "http/1.1" || alpn[1] != "h2" {
t.Errorf("alpn = %v, want [http/1.1 h2]", alpn)
}
}
func TestParseTLSClientHelloMsgData_SupportedVersionsClient(t *testing.T) {
// ClientHello with supported_versions extension (client format - list)
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // cipher suite
0x01, // compression methods length = 1
0x00, // compression method
0x00, 0x0b, // extensions length = 11
// supported_versions (client list format)
0x00, 0x2b, // type = supported_versions
0x00, 0x07, // length = 7
0x06, // list length = 6
0x03, 0x04, // TLS 1.3
0x03, 0x03, // TLS 1.2
0x03, 0x01, // TLS 1.0
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
versions, ok := m["supported_versions"].([]uint16)
if !ok {
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
}
want := []uint16{0x0304, 0x0303, 0x0301}
if !reflect.DeepEqual(versions, want) {
t.Errorf("supported_versions = %v, want %v", versions, want)
}
}
func TestParseTLSServerHelloMsgData_SupportedVersionsServer(t *testing.T) {
// ServerHello with supported_versions extension (server format - single value)
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x13, 0x01, // cipher suite
0x00, // compression method
0x00, 0x06, // extensions length = 6
// supported_versions (server format - single value)
0x00, 0x2b, // type
0x00, 0x02, // length = 2
0x03, 0x04, // TLS 1.3
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil")
}
v, ok := m["supported_versions"].(uint16)
if !ok {
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
}
if v != 0x0304 {
t.Errorf("supported_versions = 0x%04x, want 0x0304", v)
}
}

256
analyzer/tcp/fet_test.go Normal file
View File

@@ -0,0 +1,256 @@
package tcp
import (
"testing"
)
func TestPopCount(t *testing.T) {
tests := []struct {
input byte
want int
}{
{0x00, 0},
{0x01, 1},
{0x02, 1},
{0x03, 2},
{0x07, 3},
{0x0f, 4},
{0xff, 8},
{0x55, 4}, // 01010101
{0xaa, 4}, // 10101010
}
for _, tt := range tests {
if got := popCount(tt.input); got != tt.want {
t.Errorf("popCount(0x%02x) = %d, want %d", tt.input, got, tt.want)
}
}
}
func TestAveragePopCount(t *testing.T) {
tests := []struct {
name string
input []byte
want float32
}{
{"empty", []byte{}, 0},
{"all zeros", []byte{0x00, 0x00, 0x00}, 0},
{"all ones", []byte{0xff, 0xff}, 8.0},
{"mixed", []byte{0x00, 0xff}, 4.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := averagePopCount(tt.input)
if got != tt.want {
t.Errorf("averagePopCount() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsFirstSixPrintable(t *testing.T) {
tests := []struct {
name string
input []byte
want bool
}{
{"too short", []byte("abc"), false},
{"all printable", []byte("abcdef"), true},
{"non-printable at pos 0", []byte{0x00, 'a', 'b', 'c', 'd', 'e'}, false},
{"non-printable at pos 5", []byte{'a', 'b', 'c', 'd', 'e', 0x1f}, false},
{"exactly 6 printable", []byte("123456"), true},
{"spaces", []byte(" "), true},
{"non-printable middle", []byte{'a', 'b', 0x01, 'd', 'e', 'f'}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isFirstSixPrintable(tt.input)
if got != tt.want {
t.Errorf("isFirstSixPrintable(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestPrintablePercentage(t *testing.T) {
tests := []struct {
name string
input []byte
want float32
}{
{"empty", []byte{}, 0},
{"all printable", []byte("hello"), 1.0},
{"none printable", []byte{0x00, 0x01, 0x02, 0x03}, 0},
{"half printable", []byte{'a', 0x00, 'b', 0x00}, 0.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := printablePercentage(tt.input)
if got != tt.want {
t.Errorf("printablePercentage() = %v, want %v", got, tt.want)
}
})
}
}
func TestContiguousPrintable(t *testing.T) {
tests := []struct {
name string
input []byte
want int
}{
{"empty", []byte{}, 0},
{"all printable", []byte("hello world"), 11},
{"none printable", []byte{0x00, 0x01, 0x02}, 0},
{"start printable", []byte{'a', 'b', 'c', 0x00, 'd', 'e', 'f'}, 3},
{"end printable", []byte{0x00, 'a', 'b', 'c', 'd', 'e', 'f'}, 6},
{"middle printable", []byte{0x00, 'a', 'b', 'c', 0x00}, 3},
{"two segments", []byte{'a', 'b', 0x00, 'c', 'd', 'e'}, 3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := contiguousPrintable(tt.input)
if got != tt.want {
t.Errorf("contiguousPrintable() = %d, want %d", got, tt.want)
}
})
}
}
func TestIsTLSorHTTP(t *testing.T) {
tests := []struct {
name string
input []byte
want bool
}{
{"too short", []byte("AB"), false},
// TLS ClientHello: 0x16 0x03 0x01
{"tls 1.0", []byte{0x16, 0x03, 0x01}, true},
// TLS 0x17 is application data record
{"tls app data", []byte{0x17, 0x03, 0x03}, true},
{"tls max content type", []byte{0x17, 0x03, 0x09}, true},
{"bad tls content type", []byte{0x15, 0x03, 0x01}, false},
{"bad tls version", []byte{0x16, 0x04, 0x01}, false},
{"bad tls length", []byte{0x16, 0x03, 0x0a}, false},
// HTTP methods
{"GET", []byte("GET / HTTP/1.1..."), true},
{"HEAD", []byte("HEAD /index.html..."), true},
{"POST", []byte("POST /api..."), true},
{"PUT", []byte("PUT /data..."), true},
{"DELETE", []byte("DELETE /..."), true},
{"CONNECT", []byte("CONNECT proxy..."), true},
{"OPTIONS", []byte("OPTIONS *..."), true},
{"TRACE", []byte("TRACE /..."), true},
{"PATCH", []byte("PATCH /..."), true},
{"random data", []byte{0x00, 0x01, 0x02, 0x03}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isTLSorHTTP(tt.input)
if got != tt.want {
t.Errorf("isTLSorHTTP(%v) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestIsPrintable(t *testing.T) {
if !isPrintable('a') {
t.Error("'a' should be printable")
}
if isPrintable(0x00) {
t.Error("0x00 should not be printable")
}
if !isPrintable(0x20) {
t.Error("0x20 (space) should be printable")
}
if !isPrintable(0x7e) {
t.Error("0x7e (~) should be printable")
}
if isPrintable(0x7f) {
t.Error("0x7f (DEL) should not be printable")
}
}
func TestFETStream_Feed(t *testing.T) {
s := newFETStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("FET should be done after first packet")
}
m := u.M
if _, ok := m["ex1"]; !ok {
t.Error("ex1 missing")
}
if _, ok := m["ex2"]; !ok {
t.Error("ex2 missing")
}
if _, ok := m["ex3"]; !ok {
t.Error("ex3 missing")
}
if _, ok := m["ex4"]; !ok {
t.Error("ex4 missing")
}
if _, ok := m["ex5"]; !ok {
t.Error("ex5 missing")
}
if yes, ok := m["yes"].(bool); ok && yes {
t.Error("HTTP should be exempt (yes=false)")
}
if u.Type != 2 {
t.Errorf("prop update type = %d, want PropUpdateReplace", u.Type)
}
}
func TestFETStream_Feed_EncryptedLike(t *testing.T) {
s := newFETStream(nil)
data := make([]byte, 100)
for i := range data {
data[i] = byte(i % 256)
}
u, done := s.Feed(false, false, false, 0, data)
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("should be done")
}
}
func TestFETStream_Feed_Skip(t *testing.T) {
s := newFETStream(nil)
_, done := s.Feed(false, false, false, 5, []byte("data"))
if !done {
t.Error("skip != 0 should return done=true")
}
}
func TestFETStream_Feed_Empty(t *testing.T) {
s := newFETStream(nil)
u, done := s.Feed(false, false, false, 0, []byte{})
if u != nil || done {
t.Error("empty data should return nil, false")
}
}
func TestFETAnalyzer_Name(t *testing.T) {
a := &FETAnalyzer{}
if a.Name() != "fet" {
t.Errorf("Name() = %q, want fet", a.Name())
}
}
func TestFETStream_Close(t *testing.T) {
s := newFETStream(nil)
if u := s.Close(false); u != nil {
t.Error("Close should return nil")
}
}

159
analyzer/tcp/ssh_test.go Normal file
View File

@@ -0,0 +1,159 @@
package tcp
import (
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
func TestSSHAnalyzer_Name(t *testing.T) {
a := &SSHAnalyzer{}
if a.Name() != "ssh" {
t.Errorf("Name() = %q, want ssh", a.Name())
}
}
func TestSSHStream_Feed_Client(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done (server not yet received)")
}
client, ok := u.M["client"].(analyzer.PropMap)
if !ok {
t.Fatal("client prop missing")
}
if client["protocol"] != "2.0" {
t.Errorf("protocol = %v, want 2.0", client["protocol"])
}
if client["software"] != "OpenSSH_8.9p1" {
t.Errorf("software = %v, want OpenSSH_8.9p1", client["software"])
}
if comments, ok := client["comments"]; ok {
if comments != "Ubuntu-3" {
t.Errorf("comments = %v, want Ubuntu-3", comments)
}
}
}
func TestSSHStream_Feed_ClientWithComments(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_7.4 Ubuntu-3\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done")
}
client := u.M["client"].(analyzer.PropMap)
if client["comments"] != "Ubuntu-3" {
t.Errorf("comments = %v, want Ubuntu-3", client["comments"])
}
}
func TestSSHStream_Feed_Both(t *testing.T) {
s := newSSHStream(nil)
s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9\r\n"))
u, done := s.Feed(true, false, false, 0, []byte("SSH-2.0-dropbear_2022.83\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("should be done after both sides")
}
server, ok := u.M["server"].(analyzer.PropMap)
if !ok {
t.Fatal("server prop missing")
}
if server["software"] != "dropbear_2022.83" {
t.Errorf("server software = %v", server["software"])
}
}
func TestSSHStream_Feed_NotSSH(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("HTTP/1.1 200 OK\r\n"))
if u != nil {
t.Error("should return nil for non-SSH")
}
if !done {
t.Error("should be cancelled (done) for non-SSH")
}
}
func TestSSHStream_Feed_InvalidLine(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-foo bar baz\r\n"))
if u != nil {
t.Error("should return nil for invalid line (>2 fields)")
}
if !done {
t.Error("should be cancelled for invalid SSH line")
}
}
func TestSSHStream_Feed_NoLineEnd(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH"))
if u != nil {
t.Error("should return nil when no EOL found yet")
}
if done {
t.Error("should not be done, waiting for more data")
}
}
func TestSSHStream_Feed_IncompleteThenComplete(t *testing.T) {
s := newSSHStream(nil)
u, _ := s.Feed(false, false, false, 0, []byte("SSH-2.0-"))
if u != nil {
t.Error("first partial feed should return nil")
}
u, done := s.Feed(false, false, false, 0, []byte("Dropbear\r\n"))
if u == nil {
t.Fatal("second feed should return update")
}
if done {
t.Error("should not be done (only client)")
}
client := u.M["client"].(analyzer.PropMap)
if client["software"] != "Dropbear" {
t.Errorf("software = %v", client["software"])
}
}
func TestSSHStream_Feed_Skip(t *testing.T) {
s := newSSHStream(nil)
_, done := s.Feed(false, false, false, 5, []byte("data"))
if !done {
t.Error("skip != 0 should return done=true")
}
}
func TestSSHStream_Feed_Empty(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte{})
if u != nil || done {
t.Error("empty data should return nil, false")
}
}
func TestSSHStream_Close(t *testing.T) {
s := newSSHStream(nil)
s.clientBuf.Append([]byte("data"))
s.serverBuf.Append([]byte("data"))
s.clientMap = analyzer.PropMap{"key": "val"}
u := s.Close(false)
if u != nil {
t.Error("Close should return nil")
}
if s.clientBuf.Len() != 0 || s.serverBuf.Len() != 0 {
t.Error("Close should reset buffers")
}
if s.clientMap != nil || s.serverMap != nil {
t.Error("Close should nil maps")
}
}

View File

@@ -0,0 +1,187 @@
package udp
import (
"encoding/binary"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
func makeWireGuardPacket(msgType byte, body []byte) []byte {
buf := make([]byte, 1+3+len(body))
buf[0] = msgType
copy(buf[4:], body)
return buf
}
func TestWireGuardUDPStream_Feed_HandshakeInitiation(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeInitiation-4)
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
u, done := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done")
}
if u.Type != analyzer.PropUpdateReplace {
t.Errorf("Type = %d, want PropUpdateReplace", u.Type)
}
initMap, ok := u.M["handshake_initiation"].(analyzer.PropMap)
if !ok {
t.Fatal("handshake_initiation missing")
}
if idx := initMap["sender_index"].(uint32); idx != 0x01020304 {
t.Errorf("sender_index = %d, want %d", idx, 0x01020304)
}
}
func TestWireGuardUDPStream_Feed_HandshakeResponse(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeResponse-4)
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
binary.LittleEndian.PutUint32(body[4:8], 0x05060708)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeResponse, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
respMap, ok := u.M["handshake_response"].(analyzer.PropMap)
if !ok {
t.Fatal("handshake_response missing")
}
if idx := respMap["sender_index"].(uint32); idx != 0x01020304 {
t.Errorf("sender_index = %d", idx)
}
if idx := respMap["receiver_index"].(uint32); idx != 0x05060708 {
t.Errorf("receiver_index = %d", idx)
}
}
func TestWireGuardUDPStream_Feed_PacketData(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardMinSizePacketData-4)
binary.LittleEndian.PutUint32(body[0:4], 0x0a0b0c0d)
binary.LittleEndian.PutUint64(body[4:12], 42)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeData, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
dataMap, ok := u.M["packet_data"].(analyzer.PropMap)
if !ok {
t.Fatal("packet_data missing")
}
if idx := dataMap["receiver_index"].(uint32); idx != 0x0a0b0c0d {
t.Errorf("receiver_index = %d", idx)
}
if ctr := dataMap["counter"].(uint64); ctr != 42 {
t.Errorf("counter = %d", ctr)
}
}
func TestWireGuardUDPStream_Feed_CookieReply(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizePacketCookieReply-4)
binary.LittleEndian.PutUint32(body[0:4], 0x11111111)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeCookieReply, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
crMap, ok := u.M["packet_cookie_reply"].(analyzer.PropMap)
if !ok {
t.Fatal("packet_cookie_reply missing")
}
if idx := crMap["receiver_index"].(uint32); idx != 0x11111111 {
t.Errorf("receiver_index = %d", idx)
}
}
func TestWireGuardUDPStream_Feed_InvalidLength(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, done := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
if u == nil || !done {
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
}
if u == nil || !done {
u, _ = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
}
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
if u != nil || !done {
t.Error("should return done after 4 invalid packets")
}
}
func TestWireGuardUDPStream_Feed_TooShort(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, _ := s.Feed(false, []byte{1, 0, 0})
if u != nil {
t.Error("should return nil for too-short packet")
}
}
func TestWireGuardUDPStream_Feed_NonZeroReserved(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, _ := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 1, 0, 0, 0, 0})
if u != nil {
t.Error("should return nil when reserved bytes are non-zero")
}
}
func TestWireGuardUDPStream_Feed_HandshakeInitiation_WrongSize(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeInitiation-4+1)
binary.LittleEndian.PutUint32(body[0:4], 1)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
if u != nil {
t.Error("should return nil for wrong init size")
}
}
func TestWireGuardUDPStream_Close(t *testing.T) {
s := newWireGuardUDPStream(nil)
u := s.Close(false)
if u != nil {
t.Error("Close() should return nil")
}
}
func TestWireGuardUDPStream_PutSenderIndex_MatchReceiverIndex(t *testing.T) {
s := newWireGuardUDPStream(nil)
s.putSenderIndex(false, 0xdeadbeef)
if !s.matchReceiverIndex(true, 0xdeadbeef) {
t.Error("should match reverse direction")
}
if s.matchReceiverIndex(false, 0xdeadbeef) {
t.Error("should not match same direction")
}
if s.matchReceiverIndex(true, 0x12345678) {
t.Error("should not match wrong index")
}
}
func TestWireGuardUDPStream_RememberedIndexRing(t *testing.T) {
s := newWireGuardUDPStream(nil)
for i := uint32(0); i < wireguardRememberedIndexCount+2; i++ {
s.putSenderIndex(false, i)
}
if s.matchReceiverIndex(true, 0) {
t.Error("index 0 should have been evicted from ring")
}
if !s.matchReceiverIndex(true, uint32(wireguardRememberedIndexCount+1)) {
t.Error("latest index should still be in ring")
}
}
func TestWireGuardAnalyzer_Name(t *testing.T) {
a := &WireGuardAnalyzer{}
if a.Name() != "wireguard" {
t.Errorf("Name() = %q, want wireguard", a.Name())
}
}

View File

@@ -0,0 +1,321 @@
package utils
import (
"bytes"
"reflect"
"testing"
)
func TestByteBuffer_Append(t *testing.T) {
b := &ByteBuffer{}
b.Append([]byte("hello"))
b.Append([]byte(" "))
b.Append([]byte("world"))
if string(b.Buf) != "hello world" {
t.Errorf("Append() result = %q, want %q", string(b.Buf), "hello world")
}
}
func TestByteBuffer_Len(t *testing.T) {
b := &ByteBuffer{}
if b.Len() != 0 {
t.Errorf("Len() = %d, want 0", b.Len())
}
b.Append([]byte("abc"))
if b.Len() != 3 {
t.Errorf("Len() = %d, want 3", b.Len())
}
}
func TestByteBuffer_Index(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
if i := b.Index([]byte("world")); i != 6 {
t.Errorf("Index('world') = %d, want 6", i)
}
if i := b.Index([]byte("xyz")); i != -1 {
t.Errorf("Index('xyz') = %d, want -1", i)
}
if i := b.Index([]byte{}); i != 0 {
t.Errorf("Index('') = %d, want 0", i)
}
}
func TestByteBuffer_Get(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abcdef")}
data, ok := b.Get(3, true)
if !ok {
t.Fatal("Get(3, true) returned false")
}
if !bytes.Equal(data, []byte("abc")) {
t.Errorf("Get(3, true) data = %q, want %q", data, "abc")
}
if !bytes.Equal(b.Buf, []byte("def")) {
t.Errorf("after consume, Buf = %q, want %q", b.Buf, "def")
}
data, ok = b.Get(4, false)
if ok {
t.Fatal("Get(4, false) should return false (only 3 bytes left)")
}
data, ok = b.Get(2, false)
if !ok {
t.Fatal("Get(2, false) returned false")
}
if !bytes.Equal(data, []byte("de")) {
t.Errorf("Get(2, false) data = %q, want %q", data, "de")
}
if !bytes.Equal(b.Buf, []byte("def")) {
t.Errorf("after non-consume, Buf = %q, want %q (unchanged)", b.Buf, "def")
}
data, ok = b.Get(3, true)
if !ok {
t.Fatal("Get(3, true) returned false")
}
if !bytes.Equal(data, []byte("def")) {
t.Errorf("Get(3, true) data = %q, want %q", data, "def")
}
if b.Len() != 0 {
t.Errorf("after consume all, Len() = %d, want 0", b.Len())
}
}
func TestByteBuffer_GetString(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
s, ok := b.GetString(5, true)
if !ok {
t.Fatal("GetString(5, true) returned false")
}
if s != "hello" {
t.Errorf("GetString() = %q, want %q", s, "hello")
}
if b.Len() != 6 {
t.Errorf("after consume, Len() = %d, want 6", b.Len())
}
s, ok = b.GetString(7, false)
if ok {
t.Fatal("GetString(7, false) should return false")
}
s, ok = b.GetString(6, false)
if !ok {
t.Fatal("GetString(6, false) returned false")
}
if s != " world" {
t.Errorf("GetString() = %q, want %q", s, " world")
}
}
func TestByteBuffer_GetByte(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abc")}
bt, ok := b.GetByte(true)
if !ok {
t.Fatal("GetByte(true) returned false")
}
if bt != 'a' {
t.Errorf("GetByte() = %c, want 'a'", bt)
}
if b.Len() != 2 {
t.Errorf("after consume, Len() = %d, want 2", b.Len())
}
bt, ok = b.GetByte(false)
if !ok {
t.Fatal("GetByte(false) returned false")
}
if bt != 'b' {
t.Errorf("GetByte() = %c, want 'b'", bt)
}
if b.Len() != 2 {
t.Errorf("after non-consume, Len() = %d, want 2 (unchanged)", b.Len())
}
b.GetByte(true)
b.GetByte(true)
bt, ok = b.GetByte(true)
if ok {
t.Fatal("GetByte(true) on empty buffer should return false")
}
}
func TestByteBuffer_GetUint16(t *testing.T) {
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04}}
v, ok := b.GetUint16(false, true)
if !ok {
t.Fatal("GetUint16(bigEndian) returned false")
}
if v != 0x0102 {
t.Errorf("GetUint16(bigEndian) = 0x%04x, want 0x0102", v)
}
if b.Len() != 2 {
t.Errorf("after consume, Len() = %d, want 2", b.Len())
}
v, ok = b.GetUint16(true, true)
if !ok {
t.Fatal("GetUint16(littleEndian) returned false")
}
if v != 0x0403 {
t.Errorf("GetUint16(littleEndian) = 0x%04x, want 0x0403", v)
}
v, ok = b.GetUint16(false, false)
if ok {
t.Fatal("GetUint16 on empty buffer should return false")
}
}
func TestByteBuffer_GetUint32(t *testing.T) {
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}
v, ok := b.GetUint32(false, true)
if !ok {
t.Fatal("GetUint32(bigEndian) returned false")
}
if v != 0x01020304 {
t.Errorf("GetUint32(bigEndian) = 0x%08x, want 0x01020304", v)
}
if b.Len() != 4 {
t.Errorf("after consume, Len() = %d, want 4", b.Len())
}
v, ok = b.GetUint32(true, true)
if !ok {
t.Fatal("GetUint32(littleEndian) returned false")
}
if v != 0x08070605 {
t.Errorf("GetUint32(littleEndian) = 0x%08x, want 0x08070605", v)
}
v, ok = b.GetUint32(false, false)
if ok {
t.Fatal("GetUint32 on empty buffer should return false")
}
}
func TestByteBuffer_GetUntil(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello\r\nworld\r\n")}
data, ok := b.GetUntil([]byte("\r\n"), true, true)
if !ok {
t.Fatal("GetUntil(sep, include) returned false")
}
if !bytes.Equal(data, []byte("hello\r\n")) {
t.Errorf("GetUntil(include) = %q, want %q", data, "hello\\r\\n")
}
data, ok = b.GetUntil([]byte("\r\n"), false, false)
if !ok {
t.Fatal("GetUntil(sep, exclude, non-consume) returned false")
}
if !bytes.Equal(data, []byte("world")) {
t.Errorf("GetUntil(exclude) = %q, want %q", data, "world")
}
if b.Len() != 7 {
t.Errorf("after non-consume, Len() = %d, want 7", b.Len())
}
data, ok = b.GetUntil([]byte("\r\n"), true, true)
if !ok {
t.Fatal("GetUntil second (consume) returned false")
}
if !bytes.Equal(data, []byte("world\r\n")) {
t.Errorf("GetUntil second = %q, want %q", data, "world\\r\\n")
}
_, ok = b.GetUntil([]byte("xyz"), false, false)
if ok {
t.Fatal("GetUntil(not found) should return false")
}
}
func TestByteBuffer_GetSubBuffer(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
sub, ok := b.GetSubBuffer(5, true)
if !ok {
t.Fatal("GetSubBuffer() returned false")
}
if !bytes.Equal(sub.Buf, []byte("hello")) {
t.Errorf("GetSubBuffer() = %q, want %q", sub.Buf, "hello")
}
if b.Len() != 6 {
t.Errorf("after consume, Len() = %d, want 6", b.Len())
}
_, ok = b.GetSubBuffer(7, false)
if ok {
t.Fatal("GetSubBuffer(7) should return false (only 6 bytes left)")
}
}
func TestByteBuffer_Skip(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abcdef")}
ok := b.Skip(2)
if !ok {
t.Fatal("Skip(2) returned false")
}
if !bytes.Equal(b.Buf, []byte("cdef")) {
t.Errorf("after Skip(2), Buf = %q, want %q", b.Buf, "cdef")
}
ok = b.Skip(10)
if ok {
t.Fatal("Skip(10) should return false")
}
if !bytes.Equal(b.Buf, []byte("cdef")) {
t.Errorf("after failed Skip, Buf = %q, want %q (unchanged)", b.Buf, "cdef")
}
ok = b.Skip(4)
if !ok {
t.Fatal("Skip(4) returned false")
}
if b.Len() != 0 {
t.Errorf("after Skip all, Len() = %d, want 0", b.Len())
}
}
func TestByteBuffer_Reset(t *testing.T) {
b := &ByteBuffer{Buf: []byte("data")}
b.Reset()
if b.Buf != nil {
t.Errorf("after Reset, Buf = %v, want nil", b.Buf)
}
}
func TestByteBuffer_GetZeroLength(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abc")}
data, ok := b.Get(0, true)
if !ok {
t.Fatal("Get(0) returned false")
}
if len(data) != 0 {
t.Errorf("Get(0) len = %d, want 0", len(data))
}
if b.Len() != 3 {
t.Errorf("after Get(0, consume), Len() = %d, want 3 (0-length consume is no-op)", b.Len())
}
}
func TestByteBuffer_GetConsumeDoesNotMutateReturnedSlice(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello")}
data, ok := b.Get(5, true)
if !ok {
t.Fatal("Get() returned false")
}
if !reflect.DeepEqual(data, []byte("hello")) {
t.Errorf("Get() returned wrong data: %v", data)
}
if b.Len() != 0 {
t.Errorf("after consume, Len() should be 0")
}
}

185
analyzer/utils/lsm_test.go Normal file
View File

@@ -0,0 +1,185 @@
package utils
import "testing"
func TestLinearStateMachine_RunPause(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction {
callCount++
return LSMActionPause
},
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if done {
t.Error("unexpected done=true")
}
if callCount != 1 {
t.Errorf("callCount = %d, want 1", callCount)
}
}
func TestLinearStateMachine_RunNext(t *testing.T) {
callOrder := []int{}
lsm := NewLinearStateMachine(
func() LSMAction { callOrder = append(callOrder, 1); return LSMActionNext },
func() LSMAction { callOrder = append(callOrder, 2); return LSMActionNext },
func() LSMAction { callOrder = append(callOrder, 3); return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false")
}
if len(callOrder) != 3 {
t.Fatalf("callOrder len = %d, want 3", len(callOrder))
}
for i, v := range []int{1, 2, 3} {
if callOrder[i] != v {
t.Errorf("callOrder[%d] = %d, want %d", i, callOrder[i], v)
}
}
}
func TestLinearStateMachine_RunReset(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction {
callCount++
if callCount == 1 {
return LSMActionReset
}
return LSMActionNext
},
func() LSMAction { callCount++; return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false")
}
if callCount != 3 {
t.Errorf("callCount = %d, want 3 (step0 reset, step0 next, step1 next)", callCount)
}
}
func TestLinearStateMachine_RunCancel(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { callCount++; return LSMActionNext },
func() LSMAction { callCount++; return LSMActionCancel },
func() LSMAction { callCount++; return LSMActionNext },
)
cancelled, done := lsm.Run()
if !cancelled {
t.Error("unexpected cancelled=false")
}
if !done {
t.Error("unexpected done=false")
}
if callCount != 2 {
t.Errorf("callCount = %d, want 2 (third step should not execute)", callCount)
}
}
func TestLinearStateMachine_RunMixed(t *testing.T) {
pauseCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { return LSMActionNext },
func() LSMAction {
pauseCount++
if pauseCount == 1 {
return LSMActionPause
}
return LSMActionNext
},
func() LSMAction { return LSMActionNext },
func() LSMAction { return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if done {
t.Error("unexpected done=true on first run")
}
cancelled, done = lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true on second run")
}
if !done {
t.Error("unexpected done=false on second run")
}
}
func TestLinearStateMachine_RunEmpty(t *testing.T) {
lsm := NewLinearStateMachine()
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false for empty LSM")
}
}
func TestLinearStateMachine_AppendSteps(t *testing.T) {
lsm := NewLinearStateMachine(
func() LSMAction { return LSMActionNext },
)
lsm.Run()
lsm.AppendSteps(
func() LSMAction { return LSMActionNext },
)
_, done := lsm.Run()
if !done {
t.Error("unexpected done=false after AppendSteps")
}
}
func TestLinearStateMachine_Reset(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { callCount++; return LSMActionCancel },
)
lsm.Run()
if !lsm.cancelled {
t.Error("expected cancelled=true after cancel")
}
lsm.Reset()
if lsm.cancelled {
t.Error("expected cancelled=false after Reset")
}
if lsm.index != 0 {
t.Errorf("expected index=0 after Reset, got %d", lsm.index)
}
_, done := lsm.Run()
if !done {
t.Error("expected done=true, step executed again after Reset")
}
if callCount != 2 {
t.Errorf("callCount = %d, want 2 (first run + reset run)", callCount)
}
}
func TestLSMActionConstants(t *testing.T) {
if LSMActionPause != 0 {
t.Errorf("LSMActionPause = %d, want 0", LSMActionPause)
}
if LSMActionNext != 1 {
t.Errorf("LSMActionNext = %d, want 1", LSMActionNext)
}
if LSMActionReset != 2 {
t.Errorf("LSMActionReset = %d, want 2", LSMActionReset)
}
if LSMActionCancel != 3 {
t.Errorf("LSMActionCancel = %d, want 3", LSMActionCancel)
}
}

View File

@@ -0,0 +1,29 @@
package utils
import (
"reflect"
"testing"
)
func TestByteSlicesToStrings(t *testing.T) {
tests := []struct {
name string
input [][]byte
want []string
}{
{"nil", nil, []string{}},
{"empty", [][]byte{}, []string{}},
{"single", [][]byte{[]byte("hello")}, []string{"hello"}},
{"multiple", [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, []string{"foo", "bar", "baz"}},
{"empty element", [][]byte{[]byte("a"), []byte{}, []byte("b")}, []string{"a", "", "b"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ByteSlicesToStrings(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ByteSlicesToStrings() = %v, want %v", got, tt.want)
}
})
}
}

45
errors_test.go Normal file
View File

@@ -0,0 +1,45 @@
package mellaris
import (
"errors"
"testing"
)
func TestConfigError_Error(t *testing.T) {
e := ConfigError{Field: "port", Err: errors.New("must be > 0")}
want := "invalid config: port: must be > 0"
if got := e.Error(); got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}
func TestConfigError_Error_NilErr(t *testing.T) {
e := ConfigError{Field: "host", Err: nil}
want := "invalid config: host: %!s(<nil>)"
if got := e.Error(); got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}
func TestConfigError_Unwrap(t *testing.T) {
wrapped := errors.New("inner error")
e := ConfigError{Field: "field", Err: wrapped}
if got := e.Unwrap(); got != wrapped {
t.Errorf("Unwrap() = %v, want %v", got, wrapped)
}
}
func TestConfigError_Unwrap_Nil(t *testing.T) {
e := ConfigError{Field: "field"}
if got := e.Unwrap(); got != nil {
t.Errorf("Unwrap() = %v, want nil", got)
}
}
func TestConfigError_ErrorsIs(t *testing.T) {
wrapped := errors.New("inner")
e := ConfigError{Field: "x", Err: wrapped}
if !errors.Is(e, wrapped) {
t.Error("errors.Is should find wrapped error")
}
}

View File

@@ -0,0 +1,22 @@
package modifier
import (
"errors"
"testing"
)
func TestErrInvalidPacket_Error(t *testing.T) {
e := &ErrInvalidPacket{Err: errors.New("bad checksum")}
want := "invalid packet: bad checksum"
if got := e.Error(); got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}
func TestErrInvalidArgs_Error(t *testing.T) {
e := &ErrInvalidArgs{Err: errors.New("missing 'a' arg")}
want := "invalid args: missing 'a' arg"
if got := e.Error(); got != want {
t.Errorf("Error() = %q, want %q", got, want)
}
}

205
modifier/udp/dns_test.go Normal file
View File

@@ -0,0 +1,205 @@
package udp
import (
"testing"
"git.difuse.io/Difuse/Mellaris/modifier"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte {
t.Helper()
dns := &layers.DNS{
ID: id,
QR: true,
OpCode: layers.DNSOpCodeQuery,
AA: false,
RD: true,
RA: true,
ResponseCode: layers.DNSResponseCodeNoErr,
QDCount: 1,
Questions: []layers.DNSQuestion{
{
Name: []byte(name),
Type: qtype,
Class: layers.DNSClassIN,
},
},
}
buf := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, dns)
if err != nil {
t.Fatalf("failed to serialize DNS response: %v", err)
}
return buf.Bytes()
}
func TestDNSModifier_Name(t *testing.T) {
m := &DNSModifier{}
if m.Name() != "dns" {
t.Errorf("Name() = %q, want dns", m.Name())
}
}
func TestDNSModifier_New_NoArgs(t *testing.T) {
m := &DNSModifier{}
inst, err := m.New(map[string]interface{}{})
if err != nil {
t.Fatalf("New() unexpected error: %v", err)
}
if inst == nil {
t.Fatal("New() returned nil instance")
}
}
func TestDNSModifier_New_WithA(t *testing.T) {
m := &DNSModifier{}
inst, err := m.New(map[string]interface{}{
"a": "192.168.1.1",
})
if err != nil {
t.Fatalf("New() unexpected error: %v", err)
}
dnsInst, ok := inst.(*dnsModifierInstance)
if !ok {
t.Fatalf("New() returned wrong type: %T", inst)
}
if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" {
t.Errorf("A = %v, want 192.168.1.1", dnsInst.A)
}
}
func TestDNSModifier_New_WithAAAA(t *testing.T) {
m := &DNSModifier{}
inst, err := m.New(map[string]interface{}{
"aaaa": "2001:db8::1",
})
if err != nil {
t.Fatalf("New() unexpected error: %v", err)
}
dnsInst, ok := inst.(*dnsModifierInstance)
if !ok {
t.Fatalf("New() returned wrong type: %T", inst)
}
if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" {
t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA)
}
}
func TestDNSModifier_New_InvalidA(t *testing.T) {
m := &DNSModifier{}
_, err := m.New(map[string]interface{}{
"a": "not-an-ip",
})
if err == nil {
t.Fatal("New() expected error for invalid A")
}
}
func TestDNSModifier_New_InvalidAAAA(t *testing.T) {
m := &DNSModifier{}
_, err := m.New(map[string]interface{}{
"aaaa": "not-an-ip",
})
if err == nil {
t.Fatal("New() expected error for invalid AAAA")
}
}
func TestDNSModifierInstance_Process_AType(t *testing.T) {
m := &DNSModifier{}
inst, err := m.New(map[string]interface{}{
"a": "10.10.10.10",
})
if err != nil {
t.Fatal(err)
}
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com")
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
if err != nil {
t.Fatalf("Process() error: %v", err)
}
result := &layers.DNS{}
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
if err != nil {
t.Fatalf("decode result: %v", err)
}
if len(result.Answers) != 1 {
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
}
if result.Answers[0].Type != layers.DNSTypeA {
t.Errorf("answer type = %d, want A", result.Answers[0].Type)
}
if result.Answers[0].IP.String() != "10.10.10.10" {
t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP)
}
}
func TestDNSModifierInstance_Process_AAAAType(t *testing.T) {
m := &DNSModifier{}
inst, err := m.New(map[string]interface{}{
"aaaa": "2001:db8::1",
})
if err != nil {
t.Fatal(err)
}
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com")
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
if err != nil {
t.Fatalf("Process() error: %v", err)
}
result := &layers.DNS{}
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
if err != nil {
t.Fatalf("decode result: %v", err)
}
if len(result.Answers) != 1 {
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
}
if result.Answers[0].Type != layers.DNSTypeAAAA {
t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type)
}
}
func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) {
inst := &dnsModifierInstance{}
query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com")
query[2] &^= 0x80 // clear QR bit
_, err := inst.Process(query)
if err == nil {
t.Fatal("expected error for non-response")
}
errPkt, ok := err.(*modifier.ErrInvalidPacket)
if !ok {
t.Fatalf("expected *ErrInvalidPacket, got %T", err)
}
_ = errPkt
}
func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) {
inst := &dnsModifierInstance{}
dns := &layers.DNS{
ID: 1,
QR: true,
RD: true,
RA: true,
ResponseCode: layers.DNSResponseCodeNoErr,
}
buf := gopacket.NewSerializeBuffer()
gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns)
_, err := inst.Process(buf.Bytes())
if err == nil {
t.Fatal("expected error for empty questions")
}
}

View File

@@ -0,0 +1,98 @@
package builtins
import (
"net"
"testing"
)
func TestCompileCIDR(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
wantStr string
}{
{"valid ipv4", "192.168.0.0/24", false, "192.168.0.0/24"},
{"valid ipv6", "2001:db8::/32", false, "2001:db8::/32"},
{"valid host ipv4", "10.0.0.1/32", false, "10.0.0.1/32"},
{"valid host ipv6", "::1/128", false, "::1/128"},
{"invalid no mask", "192.168.0.0", true, ""},
{"invalid bad ip", "not-an-ip/24", true, ""},
{"invalid empty", "", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CompileCIDR(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("CompileCIDR(%q) expected error, got nil", tt.input)
}
return
}
if err != nil {
t.Fatalf("CompileCIDR(%q) unexpected error: %v", tt.input, err)
}
if got.String() != tt.wantStr {
t.Errorf("CompileCIDR(%q) = %q, want %q", tt.input, got.String(), tt.wantStr)
}
})
}
}
func TestMatchCIDR(t *testing.T) {
cidr := mustCompileCIDR(t, "192.168.0.0/24")
tests := []struct {
name string
ip string
want bool
}{
{"inside", "192.168.0.1", true},
{"boundary low", "192.168.0.0", true},
{"boundary high", "192.168.0.255", true},
{"outside", "192.168.1.1", false},
{"different network", "10.0.0.1", false},
{"invalid ip", "not-an-ip", false},
{"empty", "", false},
{"ipv6 in ipv4", "::1", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MatchCIDR(tt.ip, cidr)
if got != tt.want {
t.Errorf("MatchCIDR(%q, %q) = %v, want %v", tt.ip, cidr, got, tt.want)
}
})
}
}
func TestMatchCIDR_IPv6(t *testing.T) {
cidr := mustCompileCIDR(t, "2001:db8::/32")
inside := "2001:db8::1"
if !MatchCIDR(inside, cidr) {
t.Errorf("MatchCIDR(%q) should be true", inside)
}
outside := "2001:db9::1"
if MatchCIDR(outside, cidr) {
t.Errorf("MatchCIDR(%q) should be false", outside)
}
}
func TestMatchCIDR_NullResult(t *testing.T) {
if MatchCIDR("10.0.0.1", &net.IPNet{}) {
t.Error("MatchCIDR with empty IPNet should return false")
}
}
func mustCompileCIDR(t *testing.T, cidr string) *net.IPNet {
t.Helper()
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatalf("failed to parse CIDR %q: %v", cidr, err)
}
return ipNet
}

View File

@@ -0,0 +1,115 @@
package geo
import (
"testing"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
)
type fakeGeoLoader struct {
geoip map[string]*v2geo.GeoIP
geosite map[string]*v2geo.GeoSite
}
func (l *fakeGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
return l.geoip, nil
}
func (l *fakeGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
return l.geosite, nil
}
func TestGeoMatcher_MatchGeoIp_Cached(t *testing.T) {
loader := &fakeGeoLoader{
geoip: map[string]*v2geo.GeoIP{
"us": {
Cidr: []*v2geo.CIDR{
{Ip: ipv4(8, 8, 8, 0), Prefix: 24},
},
},
},
}
g := NewGeoMatcher("", "")
g.geoLoader = loader
if !g.MatchGeoIp("8.8.8.8", "US") {
t.Error("MatchGeoIp should match 8.8.8.8 in US range")
}
if g.MatchGeoIp("9.9.9.9", "US") {
t.Error("MatchGeoIp should not match 9.9.9.9 in US range")
}
}
func TestGeoMatcher_MatchGeoIp_EmptyCondition(t *testing.T) {
g := NewGeoMatcher("", "")
if g.MatchGeoIp("1.2.3.4", "") {
t.Error("MatchGeoIp with empty condition should return false")
}
}
func TestGeoMatcher_MatchGeoIp_InvalidIP(t *testing.T) {
g := NewGeoMatcher("", "")
if g.MatchGeoIp("not-an-ip", "us") {
t.Error("MatchGeoIp with invalid IP should return false")
}
}
func TestGeoMatcher_MatchGeoIp_MissingCountry(t *testing.T) {
loader := &fakeGeoLoader{
geoip: map[string]*v2geo.GeoIP{},
}
g := NewGeoMatcher("", "")
g.geoLoader = loader
if g.MatchGeoIp("8.8.8.8", "us") {
t.Error("MatchGeoIp for missing country should return false")
}
}
func TestGeoMatcher_MatchGeoSite(t *testing.T) {
loader := &fakeGeoLoader{
geosite: map[string]*v2geo.GeoSite{
"openai": {
Domain: []*v2geo.Domain{
{Type: v2geo.Domain_Plain, Value: "openai"},
{Type: v2geo.Domain_Full, Value: "chatgpt.com"},
},
},
},
}
g := NewGeoMatcher("", "")
g.geoLoader = loader
if !g.MatchGeoSite("api.openai.com", "openai") {
t.Error("MatchGeoSite should match via plain domain")
}
if !g.MatchGeoSite("chatgpt.com", "openai") {
t.Error("MatchGeoSite should match via full domain")
}
if g.MatchGeoSite("google.com", "openai") {
t.Error("MatchGeoSite should not match unrelated host")
}
}
func TestGeoMatcher_MatchGeoSite_EmptyCondition(t *testing.T) {
g := NewGeoMatcher("", "")
if g.MatchGeoSite("test.com", "") {
t.Error("MatchGeoSite with empty condition should return false")
}
}
func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
loader := &fakeGeoLoader{
geosite: map[string]*v2geo.GeoSite{},
}
g := NewGeoMatcher("", "")
g.geoLoader = loader
if g.MatchGeoSite("test.com", "nonexistent") {
t.Error("MatchGeoSite for missing site should return false")
}
}
func ipv4(a, b, c, d byte) []byte {
return []byte{a, b, c, d}
}

Binary file not shown.

View File

@@ -0,0 +1,324 @@
package geo
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
)
func TestParseGeoSiteName(t *testing.T) {
tests := []struct {
input string
wantBase string
wantAttrs []string
}{
{"google", "google", nil},
{"google@ads", "google", []string{"ads"}},
{"google@ads@news", "google", []string{"ads", "news"}},
{" google ", "google", nil},
{" google @ ads ", "google", []string{"ads"}},
{"openai@ ads @ news ", "openai", []string{"ads", "news"}},
{"@onlyattrs", "", []string{"onlyattrs"}},
{"", "", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
base, attrs := parseGeoSiteName(tt.input)
if base != tt.wantBase {
t.Errorf("parseGeoSiteName(%q) base = %q, want %q", tt.input, base, tt.wantBase)
}
if len(attrs) != len(tt.wantAttrs) {
t.Fatalf("parseGeoSiteName(%q) attrs len = %d, want %d", tt.input, len(attrs), len(tt.wantAttrs))
}
for i, attr := range attrs {
if attr != tt.wantAttrs[i] {
t.Errorf("parseGeoSiteName(%q) attrs[%d] = %q, want %q", tt.input, i, attr, tt.wantAttrs[i])
}
}
})
}
}
func TestHostInfo_String(t *testing.T) {
h := HostInfo{
Name: "example.com",
IPv4: net.ParseIP("1.2.3.4"),
IPv6: net.ParseIP("::1"),
}
want := "example.com|1.2.3.4|::1"
if got := h.String(); got != want {
t.Errorf("HostInfo.String() = %q, want %q", got, want)
}
}
func TestHostInfo_String_Partial(t *testing.T) {
h := HostInfo{
Name: "test.com",
IPv4: net.ParseIP("10.0.0.1"),
}
want := "test.com|10.0.0.1|<nil>"
if got := h.String(); got != want {
t.Errorf("HostInfo.String() = %q, want %q", got, want)
}
}
func TestGeoipMatcher_Match(t *testing.T) {
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
_, n4_2, _ := net.ParseCIDR("192.168.0.0/16")
m := &geoipMatcher{
N4: []*net.IPNet{n4, n4_2},
}
tests := []struct {
name string
host HostInfo
want bool
}{
{"ipv4 match", HostInfo{IPv4: net.ParseIP("10.1.2.3")}, true},
{"ipv4 no match", HostInfo{IPv4: net.ParseIP("172.16.0.1")}, false},
{"ipv4 match second net", HostInfo{IPv4: net.ParseIP("192.168.1.1")}, true},
{"no ip", HostInfo{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := m.Match(tt.host); got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestGeoipMatcher_Match_Inverse(t *testing.T) {
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
m := &geoipMatcher{
N4: []*net.IPNet{n4},
Inverse: true,
}
if m.Match(HostInfo{IPv4: net.ParseIP("10.1.2.3")}) {
t.Error("Inverse: inside range should return false")
}
if !m.Match(HostInfo{IPv4: net.ParseIP("172.16.0.1")}) {
t.Error("Inverse: outside range should return true")
}
if !m.Match(HostInfo{}) {
t.Error("Inverse: no IP should return true")
}
}
func TestGeoipMatcher_Match_IPv6(t *testing.T) {
_, n6, _ := net.ParseCIDR("2001:db8::/32")
m := &geoipMatcher{
N6: []*net.IPNet{n6},
}
if !m.Match(HostInfo{IPv6: net.ParseIP("2001:db8::1")}) {
t.Error("IPv6 match failed")
}
if m.Match(HostInfo{IPv6: net.ParseIP("2001:db9::1")}) {
t.Error("IPv6 should not match")
}
}
func TestGeositeMatcher_matchDomain_Plain(t *testing.T) {
m := &geositeMatcher{}
d := geositeDomain{
Type: geositeDomainPlain,
Value: "openai",
}
if !m.matchDomain(d, HostInfo{Name: "api.openai.com"}) {
t.Error("plain domain should match via substring")
}
if m.matchDomain(d, HostInfo{Name: "google.com"}) {
t.Error("plain domain should not match unrelated host")
}
}
func TestGeositeMatcher_matchDomain_Full(t *testing.T) {
m := &geositeMatcher{}
d := geositeDomain{
Type: geositeDomainFull,
Value: "example.com",
}
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
t.Error("full domain should match exact")
}
if m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
t.Error("full domain should not match subdomain")
}
}
func TestGeositeMatcher_matchDomain_Root(t *testing.T) {
m := &geositeMatcher{}
d := geositeDomain{
Type: geositeDomainRoot,
Value: "example.com",
}
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
t.Error("root domain should match exact")
}
if !m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
t.Error("root domain should match subdomain")
}
if m.matchDomain(d, HostInfo{Name: "www.example.com.au"}) {
t.Error("root domain should not match unrelated suffix")
}
}
func TestGeositeMatcher_matchDomain_Attrs(t *testing.T) {
m := &geositeMatcher{Attrs: []string{"ads"}}
d := geositeDomain{
Type: geositeDomainPlain,
Value: "google",
Attrs: map[string]bool{"ads": true},
}
if !m.matchDomain(d, HostInfo{Name: "google.com"}) {
t.Error("should match when domain has required attr")
}
dNoAttrs := geositeDomain{
Type: geositeDomainPlain,
Value: "google",
Attrs: map[string]bool{},
}
if m.matchDomain(dNoAttrs, HostInfo{Name: "google.com"}) {
t.Error("should not match when domain lacks required attr")
}
dOtherAttrs := geositeDomain{
Type: geositeDomainPlain,
Value: "google",
Attrs: map[string]bool{"news": true},
}
if m.matchDomain(dOtherAttrs, HostInfo{Name: "google.com"}) {
t.Error("should not match when domain has wrong attr")
}
}
func TestGeositeMatcher_Match(t *testing.T) {
m := &geositeMatcher{
Domains: []geositeDomain{
{Type: geositeDomainFull, Value: "exact.com"},
{Type: geositeDomainPlain, Value: "partial"},
},
}
if !m.Match(HostInfo{Name: "exact.com"}) {
t.Error("should match full domain")
}
if !m.Match(HostInfo{Name: "www.partial.net"}) {
t.Error("should match partial domain")
}
if m.Match(HostInfo{Name: "other.net"}) {
t.Error("should not match unrelated host")
}
}
func TestDomainAttributeToMap(t *testing.T) {
attrs := []*v2geo.Domain_Attribute{
{Key: "ads"},
{Key: "news"},
}
got := domainAttributeToMap(attrs)
if len(got) != 2 || !got["ads"] || !got["news"] {
t.Errorf("domainAttributeToMap = %v, want {ads:true, news:true}", got)
}
got2 := domainAttributeToMap(nil)
if len(got2) != 0 {
t.Errorf("domainAttributeToMap(nil) = %v, want empty map", got2)
}
}
func TestNewGeoIPMatcher(t *testing.T) {
list := &v2geo.GeoIP{
Cidr: []*v2geo.CIDR{
{Ip: net.IPv4(10, 0, 0, 0).To4(), Prefix: 8},
{Ip: net.IPv4(192, 168, 0, 0).To4(), Prefix: 16},
},
InverseMatch: false,
}
m, err := newGeoIPMatcher(list)
if err != nil {
t.Fatalf("newGeoIPMatcher error: %v", err)
}
if len(m.N4) != 2 {
t.Errorf("expected 2 IPv4 nets, got %d", len(m.N4))
}
if m.Inverse {
t.Error("Inverse should be false")
}
// Verify sorted order: 10.0.0.0/8 < 192.168.0.0/16
if m.N4[0].IP.String() != "10.0.0.0" {
t.Errorf("N4[0] = %s, want 10.0.0.0", m.N4[0].IP)
}
if m.N4[1].IP.String() != "192.168.0.0" {
t.Errorf("N4[1] = %s, want 192.168.0.0", m.N4[1].IP)
}
}
func TestNewGeoIPMatcher_IPv6(t *testing.T) {
list := &v2geo.GeoIP{
Cidr: []*v2geo.CIDR{
{Ip: net.ParseIP("2001:db8::"), Prefix: 32},
},
}
m, err := newGeoIPMatcher(list)
if err != nil {
t.Fatalf("newGeoIPMatcher error: %v", err)
}
if len(m.N6) != 1 {
t.Errorf("expected 1 IPv6 net, got %d", len(m.N6))
}
}
func TestNewGeoIPMatcher_InvalidIPLength(t *testing.T) {
list := &v2geo.GeoIP{
Cidr: []*v2geo.CIDR{
{Ip: []byte{1, 2, 3}, Prefix: 24},
},
}
_, err := newGeoIPMatcher(list)
if err == nil {
t.Error("expected error for invalid IP length")
}
}
func TestNewGeositeMatcher(t *testing.T) {
list := &v2geo.GeoSite{
Domain: []*v2geo.Domain{
{Type: v2geo.Domain_Plain, Value: "google"},
{Type: v2geo.Domain_Full, Value: "exact.com"},
},
}
m, err := newGeositeMatcher(list, nil)
if err != nil {
t.Fatalf("newGeositeMatcher error: %v", err)
}
if len(m.Domains) != 2 {
t.Errorf("expected 2 domains, got %d", len(m.Domains))
}
}
func TestNewGeositeMatcher_WithAttrs(t *testing.T) {
list := &v2geo.GeoSite{
Domain: []*v2geo.Domain{
{
Type: v2geo.Domain_RootDomain,
Value: "google.com",
Attribute: []*v2geo.Domain_Attribute{
{Key: "ads"},
},
},
},
}
m, err := newGeositeMatcher(list, []string{"ads"})
if err != nil {
t.Fatalf("newGeositeMatcher error: %v", err)
}
if !m.Match(HostInfo{Name: "www.google.com"}) {
t.Error("should match with root domain and attr")
}
}

View File

@@ -31,6 +31,9 @@ type ExprRule struct {
Log bool `yaml:"log"`
Modifier ModifierEntry `yaml:"modifier"`
Expr string `yaml:"expr"`
StartTime string `yaml:"start_time"`
StopTime string `yaml:"stop_time"`
Weekdays []string `yaml:"weekdays"`
}
type ModifierEntry struct {
@@ -56,6 +59,10 @@ type compiledExprRule struct {
ModInstance modifier.Instance
Program *vm.Program
GeoSiteConditions []string
StartTimeSecs int // seconds since midnight, -1 if unset
StopTimeSecs int // seconds since midnight, -1 if unset
Weekdays []time.Weekday
WeekdaysNegated bool
}
var _ Ruleset = (*exprRuleset)(nil)
@@ -73,10 +80,13 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
env := streamInfoToExprEnv(info)
now := time.Now()
for _, rule := range r.Rules {
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
continue
}
v, err := vm.Run(rule.Program, env)
if err != nil {
// Log the error and continue to the next rule.
r.Logger.MatchError(info, rule.Name, err)
continue
}
@@ -163,12 +173,34 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
depAnMap[name] = a
}
}
startSecs := -1
if rule.StartTime != "" {
startSecs, err = parseTimeOfDay(rule.StartTime)
if err != nil {
return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err)
}
}
stopSecs := -1
if rule.StopTime != "" {
stopSecs, err = parseTimeOfDay(rule.StopTime)
if err != nil {
return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err)
}
}
weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays)
if err != nil {
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
}
cr := compiledExprRule{
Name: rule.Name,
Action: action,
Log: rule.Log,
Program: program,
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
StartTimeSecs: startSecs,
StopTimeSecs: stopSecs,
Weekdays: weekdays,
WeekdaysNegated: weekdaysNegated,
}
if action != nil && *action == ActionModify {
mod, ok := fullModMap[rule.Modifier.Name]
@@ -391,6 +423,87 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
}, geoMatcher
}
func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool {
if startSecs >= 0 || stopSecs >= 0 {
currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second()
if startSecs >= 0 && stopSecs >= 0 {
if startSecs <= stopSecs {
if currentSecs < startSecs || currentSecs > stopSecs {
return false
}
} else {
if currentSecs < startSecs && currentSecs > stopSecs {
return false
}
}
} else if startSecs >= 0 {
if currentSecs < startSecs {
return false
}
} else if currentSecs > stopSecs {
return false
}
}
if len(weekdays) > 0 {
current := now.Weekday()
found := false
for _, d := range weekdays {
if current == d {
found = true
break
}
}
if negated == found {
return false
}
}
return true
}
func parseTimeOfDay(s string) (int, error) {
t, err := time.Parse("15:04:05", s)
if err != nil {
return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s)
}
return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil
}
func parseWeekdays(days []string) ([]time.Weekday, bool, error) {
if len(days) == 0 {
return nil, false, nil
}
negated := false
parsed := make([]time.Weekday, 0, len(days))
for i, d := range days {
d = strings.TrimSpace(d)
if i == 0 && strings.HasPrefix(d, "!") {
negated = true
d = strings.TrimSpace(strings.TrimPrefix(d, "!"))
}
var wd time.Weekday
switch strings.ToLower(d) {
case "sun", "sunday":
wd = time.Sunday
case "mon", "monday":
wd = time.Monday
case "tue", "tues", "tuesday":
wd = time.Tuesday
case "wed", "wednesday":
wd = time.Wednesday
case "thu", "thur", "thurs", "thursday":
wd = time.Thursday
case "fri", "friday":
wd = time.Friday
case "sat", "saturday":
wd = time.Saturday
default:
return nil, false, fmt.Errorf("invalid weekday %q", d)
}
parsed = append(parsed, wd)
}
return parsed, negated, nil
}
const rulesetLogMetaKey = "_ruleset"
func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo {

121
ruleset/interface_test.go Normal file
View File

@@ -0,0 +1,121 @@
package ruleset
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
func TestAction_String(t *testing.T) {
tests := []struct {
action Action
want string
}{
{ActionMaybe, "maybe"},
{ActionAllow, "allow"},
{ActionBlock, "block"},
{ActionDrop, "drop"},
{ActionModify, "modify"},
{Action(99), "unknown"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := tt.action.String(); got != tt.want {
t.Errorf("Action.String() = %q, want %q", got, tt.want)
}
})
}
}
func TestProtocol_String(t *testing.T) {
tests := []struct {
protocol Protocol
want string
}{
{ProtocolTCP, "tcp"},
{ProtocolUDP, "udp"},
{Protocol(99), "unknown"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := tt.protocol.String(); got != tt.want {
t.Errorf("Protocol.String() = %q, want %q", got, tt.protocol)
}
})
}
}
func TestProtocol_Constants(t *testing.T) {
if ProtocolTCP != 0 {
t.Errorf("ProtocolTCP = %d, want 0", ProtocolTCP)
}
if ProtocolUDP != 1 {
t.Errorf("ProtocolUDP = %d, want 1", ProtocolUDP)
}
}
func TestAction_Constants(t *testing.T) {
if ActionMaybe != 0 {
t.Errorf("ActionMaybe = %d, want 0", ActionMaybe)
}
if ActionAllow != 1 {
t.Errorf("ActionAllow = %d, want 1", ActionAllow)
}
if ActionBlock != 2 {
t.Errorf("ActionBlock = %d, want 2", ActionBlock)
}
}
func TestStreamInfo_SrcString(t *testing.T) {
info := StreamInfo{
SrcIP: net.ParseIP("192.168.1.1"),
SrcPort: 8080,
}
want := "192.168.1.1:8080"
if got := info.SrcString(); got != want {
t.Errorf("SrcString() = %q, want %q", got, want)
}
}
func TestStreamInfo_DstString(t *testing.T) {
info := StreamInfo{
DstIP: net.ParseIP("10.0.0.1"),
DstPort: 443,
}
want := "10.0.0.1:443"
if got := info.DstString(); got != want {
t.Errorf("DstString() = %q, want %q", got, want)
}
}
func TestStreamInfo_SrcString_IPv6(t *testing.T) {
info := StreamInfo{
SrcIP: net.ParseIP("::1"),
SrcPort: 53,
}
want := "[::1]:53"
if got := info.SrcString(); got != want {
t.Errorf("SrcString() = %q, want %q", got, want)
}
}
func TestMatchResult_ZeroValue(t *testing.T) {
var mr MatchResult
if mr.Action != ActionMaybe {
t.Errorf("zero MatchResult.Action = %v, want ActionMaybe (0)", mr.Action)
}
}
func TestStreamInfo_PropsInitialization(t *testing.T) {
info := StreamInfo{
Props: analyzer.CombinedPropMap{
"tls": analyzer.PropMap{"sni": "example.com"},
},
}
if info.Props.Get("tls", "sni") != "example.com" {
t.Error("StreamInfo.Props not properly initialized")
}
}