test: improve coverage across package
This commit is contained in:
23
README.md
23
README.md
@@ -41,6 +41,29 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run all tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
Run tests with coverage:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go test -cover ./... # per-package summary
|
||||||
|
go test -coverprofile=coverage.out ./... # full profile
|
||||||
|
go tool cover -html=coverage.out -o coverage.html # HTML report
|
||||||
|
```
|
||||||
|
|
||||||
|
Run tests for a specific package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go test -v ./analyzer/tcp/
|
||||||
|
go test -v -run TestSSH ./analyzer/tcp/
|
||||||
|
```
|
||||||
|
|
||||||
Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW
|
Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW
|
||||||
|
|
||||||
## LICENSE
|
## LICENSE
|
||||||
|
|||||||
124
analyzer/interface_test.go
Normal file
124
analyzer/interface_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package analyzer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPropMap_Get(t *testing.T) {
|
||||||
|
m := PropMap{
|
||||||
|
"a": "value-a",
|
||||||
|
"nested": PropMap{
|
||||||
|
"b": "value-b",
|
||||||
|
"deep": PropMap{
|
||||||
|
"c": "value-c",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"simple key", "a", "value-a"},
|
||||||
|
{"nested key", "nested.b", "value-b"},
|
||||||
|
{"deeply nested", "nested.deep.c", "value-c"},
|
||||||
|
{"non-existent top", "x", nil},
|
||||||
|
{"non-existent nested", "nested.x", nil},
|
||||||
|
{"empty key", "", nil},
|
||||||
|
{"partial path", "nested", PropMap{"b": "value-b", "deep": PropMap{"c": "value-c"}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := m.Get(tt.key)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("PropMap.Get(%q) = %v, want %v", tt.key, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPropMap_Get_NilMap(t *testing.T) {
|
||||||
|
var m PropMap
|
||||||
|
if got := m.Get("any"); got != nil {
|
||||||
|
t.Errorf("nil PropMap.Get() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPropMap_Get_EmptyMap(t *testing.T) {
|
||||||
|
m := PropMap{}
|
||||||
|
if got := m.Get("any"); got != nil {
|
||||||
|
t.Errorf("empty PropMap.Get() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPropMap_Get_IntermediateNonMap(t *testing.T) {
|
||||||
|
m := PropMap{
|
||||||
|
"a": "hello",
|
||||||
|
}
|
||||||
|
if got := m.Get("a.b.c"); got != nil {
|
||||||
|
t.Errorf("PropMap.Get() through non-map = %v, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombinedPropMap_Get(t *testing.T) {
|
||||||
|
cm := CombinedPropMap{
|
||||||
|
"tls": PropMap{
|
||||||
|
"req": PropMap{
|
||||||
|
"sni": "example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"http": PropMap{
|
||||||
|
"resp": PropMap{
|
||||||
|
"status": 200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
analyzer string
|
||||||
|
key string
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"tls sni", "tls", "req.sni", "example.com"},
|
||||||
|
{"http status", "http", "resp.status", 200},
|
||||||
|
{"unknown analyzer", "dns", "any", nil},
|
||||||
|
{"valid analyzer bad key", "tls", "req.nonexistent", nil},
|
||||||
|
{"empty analyzer", "", "key", nil},
|
||||||
|
{"empty key", "tls", "", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := cm.Get(tt.analyzer, tt.key)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("CombinedPropMap.Get(%q, %q) = %v, want %v", tt.analyzer, tt.key, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombinedPropMap_Get_Nil(t *testing.T) {
|
||||||
|
var cm CombinedPropMap
|
||||||
|
if got := cm.Get("any", "key"); got != nil {
|
||||||
|
t.Errorf("nil CombinedPropMap.Get() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPropUpdateType_Values(t *testing.T) {
|
||||||
|
if PropUpdateNone != 0 {
|
||||||
|
t.Errorf("PropUpdateNone = %d, want 0", PropUpdateNone)
|
||||||
|
}
|
||||||
|
if PropUpdateMerge != 1 {
|
||||||
|
t.Errorf("PropUpdateMerge = %d, want 1", PropUpdateMerge)
|
||||||
|
}
|
||||||
|
if PropUpdateReplace != 2 {
|
||||||
|
t.Errorf("PropUpdateReplace = %d, want 2", PropUpdateReplace)
|
||||||
|
}
|
||||||
|
if PropUpdateDelete != 3 {
|
||||||
|
t.Errorf("PropUpdateDelete = %d, want 3", PropUpdateDelete)
|
||||||
|
}
|
||||||
|
}
|
||||||
319
analyzer/internal/tls_test.go
Normal file
319
analyzer/internal/tls_test.go
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildClientHelloMsg(t *testing.T) *utils.ByteBuffer {
|
||||||
|
t.Helper()
|
||||||
|
// Bytes taken from the standard TLS test vector, starting after the record
|
||||||
|
// header (5 bytes) and handshake header (4 bytes).
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version TLS 1.2
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length = 0
|
||||||
|
0x00, 0x20, // cipher suites length = 32 bytes = 16 suites
|
||||||
|
0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30,
|
||||||
|
0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09,
|
||||||
|
0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d,
|
||||||
|
0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, // ciphers
|
||||||
|
0x01, // compression methods length = 1
|
||||||
|
0x00, // compression method = null
|
||||||
|
0x00, 0x58, // extensions length = 88
|
||||||
|
// extension: server_name
|
||||||
|
0x00, 0x00, // type = server_name
|
||||||
|
0x00, 0x18, // length = 24
|
||||||
|
0x00, 0x16, // server name list length = 22
|
||||||
|
0x00, // name type = hostname
|
||||||
|
0x00, 0x13, // name length = 19
|
||||||
|
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'u', 'l', 'f', 'h', 'e', 'i', 'm', '.', 'n', 'e', 't',
|
||||||
|
// extension: status_request
|
||||||
|
0x00, 0x05, // type = status_request
|
||||||
|
0x00, 0x05, // length = 5
|
||||||
|
0x01, 0x00, 0x00, 0x00, 0x00,
|
||||||
|
// extension: supported_groups
|
||||||
|
0x00, 0x0a, // type = supported_groups
|
||||||
|
0x00, 0x0a, // length = 10
|
||||||
|
0x00, 0x08, // list length = 8
|
||||||
|
0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19,
|
||||||
|
// extension: ec_point_formats
|
||||||
|
0x00, 0x0b, // type = ec_point_formats
|
||||||
|
0x00, 0x02, // length = 2
|
||||||
|
0x01, 0x00,
|
||||||
|
// extension: signature_algorithms
|
||||||
|
0x00, 0x0d, // type = signature_algorithms
|
||||||
|
0x00, 0x12, // length = 18
|
||||||
|
0x00, 0x10, // list length = 16
|
||||||
|
0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03,
|
||||||
|
0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03,
|
||||||
|
// extension: renegotiation_info
|
||||||
|
0xff, 0x01, // type = renegotiation_info
|
||||||
|
0x00, 0x01, // length = 1
|
||||||
|
0x00,
|
||||||
|
// extension: extended_master_secret (empty)
|
||||||
|
0x00, 0x17, // type = extended_master_secret
|
||||||
|
0x00, 0x00, // length = 0
|
||||||
|
// extension: session_ticket (empty)
|
||||||
|
0x00, 0x23, // type = session_ticket
|
||||||
|
0x00, 0x00, // length = 0
|
||||||
|
}
|
||||||
|
return &utils.ByteBuffer{Buf: body}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSClientHelloMsgData(t *testing.T) {
|
||||||
|
chBuf := buildClientHelloMsg(t)
|
||||||
|
m := ParseTLSClientHelloMsgData(chBuf)
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
wantVersion := uint16(0x0303)
|
||||||
|
if v, ok := m["version"].(uint16); !ok || v != wantVersion {
|
||||||
|
t.Errorf("version = %v, want %v", m["version"], wantVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCiphers := []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10}
|
||||||
|
if c, ok := m["ciphers"].([]uint16); ok {
|
||||||
|
if !reflect.DeepEqual(c, wantCiphers) {
|
||||||
|
t.Errorf("ciphers = %v, want %v", c, wantCiphers)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("ciphers missing or wrong type: %T", m["ciphers"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if sni, ok := m["sni"].(string); !ok || sni != "example.ulfheim.net" {
|
||||||
|
t.Errorf("sni = %q, want %q", m["sni"], "example.ulfheim.net")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["compression"]; !ok {
|
||||||
|
t.Error("compression key missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["session"]; !ok {
|
||||||
|
t.Error("session key missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["random"]; !ok {
|
||||||
|
t.Error("random key missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSServerHelloMsgData(t *testing.T) {
|
||||||
|
// ServerHello message body (after record header + handshake header)
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version TLS 1.2
|
||||||
|
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
|
||||||
|
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
|
||||||
|
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
|
||||||
|
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, // random
|
||||||
|
0x00, // session ID length = 0
|
||||||
|
0xc0, 0x13, // cipher suite
|
||||||
|
0x00, // compression method
|
||||||
|
0x00, 0x05, // extensions length = 5
|
||||||
|
// extension: renegotiation_info
|
||||||
|
0xff, 0x01, // type
|
||||||
|
0x00, 0x01, // length = 1
|
||||||
|
0x00,
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSServerHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCipher := uint16(0xc013)
|
||||||
|
if c, ok := m["cipher"].(uint16); !ok || c != wantCipher {
|
||||||
|
t.Errorf("cipher = %v, want %v", m["cipher"], wantCipher)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCompression := uint8(0)
|
||||||
|
if c, ok := m["compression"].(uint8); !ok || c != wantCompression {
|
||||||
|
t.Errorf("compression = %v, want %v", m["compression"], wantCompression)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSServerHelloMsgData_NoExtensions(t *testing.T) {
|
||||||
|
// ServerHello message without extensions
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x00, 0xff, // cipher suite
|
||||||
|
0x00, // compression method
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSServerHelloMsgData returned nil for no-extensions case")
|
||||||
|
}
|
||||||
|
if _, ok := m["cipher"]; !ok {
|
||||||
|
t.Error("cipher key missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSClientHelloMsgData_Truncated(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
buf []byte
|
||||||
|
}{
|
||||||
|
{"too short for session id", []byte{0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}},
|
||||||
|
{"odd cipher suites length", []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x00, 0x03, // odd cipher suites length
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: tt.buf})
|
||||||
|
if m != nil {
|
||||||
|
t.Error("expected nil for truncated input")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSClientHelloMsgData_ECH(t *testing.T) {
|
||||||
|
// ClientHello with ECH extension
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x00, 0x02, // cipher suites length = 2
|
||||||
|
0x13, 0x01, // TLS_AES_128_GCM_SHA256
|
||||||
|
0x01, // compression methods length = 1
|
||||||
|
0x00, // null compression
|
||||||
|
0x00, 0x07, // extensions length = 7
|
||||||
|
// ECH extension
|
||||||
|
0xfe, 0x0d, // type = encrypted_client_hello
|
||||||
|
0x00, 0x03, // length = 3
|
||||||
|
0x01, 0x02, 0x03, // some ECH data
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
ech, ok := m["ech"].(bool)
|
||||||
|
if !ok || !ech {
|
||||||
|
t.Errorf("ech = %v, want true", m["ech"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSClientHelloMsgData_ALPN(t *testing.T) {
|
||||||
|
// ClientHello with ALPN extension
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x00, 0x02, // cipher suites length = 2
|
||||||
|
0x13, 0x01, // cipher suite
|
||||||
|
0x01, // compression methods length = 1
|
||||||
|
0x00, // compression method
|
||||||
|
0x00, 0x12, // extensions length = 18
|
||||||
|
// ALPN extension
|
||||||
|
0x00, 0x10, // type = ALPN
|
||||||
|
0x00, 0x0e, // length = 14
|
||||||
|
0x00, 0x0c, // list length = 12
|
||||||
|
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
|
||||||
|
0x02, 'h', '2',
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
alpn, ok := m["alpn"].([]string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("alpn missing or wrong type: %T", m["alpn"])
|
||||||
|
}
|
||||||
|
if len(alpn) != 2 || alpn[0] != "http/1.1" || alpn[1] != "h2" {
|
||||||
|
t.Errorf("alpn = %v, want [http/1.1 h2]", alpn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSClientHelloMsgData_SupportedVersionsClient(t *testing.T) {
|
||||||
|
// ClientHello with supported_versions extension (client format - list)
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x00, 0x02, // cipher suites length = 2
|
||||||
|
0x13, 0x01, // cipher suite
|
||||||
|
0x01, // compression methods length = 1
|
||||||
|
0x00, // compression method
|
||||||
|
0x00, 0x0b, // extensions length = 11
|
||||||
|
// supported_versions (client list format)
|
||||||
|
0x00, 0x2b, // type = supported_versions
|
||||||
|
0x00, 0x07, // length = 7
|
||||||
|
0x06, // list length = 6
|
||||||
|
0x03, 0x04, // TLS 1.3
|
||||||
|
0x03, 0x03, // TLS 1.2
|
||||||
|
0x03, 0x01, // TLS 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSClientHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
versions, ok := m["supported_versions"].([]uint16)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
|
||||||
|
}
|
||||||
|
want := []uint16{0x0304, 0x0303, 0x0301}
|
||||||
|
if !reflect.DeepEqual(versions, want) {
|
||||||
|
t.Errorf("supported_versions = %v, want %v", versions, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTLSServerHelloMsgData_SupportedVersionsServer(t *testing.T) {
|
||||||
|
// ServerHello with supported_versions extension (server format - single value)
|
||||||
|
body := []byte{
|
||||||
|
0x03, 0x03, // version
|
||||||
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||||
|
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||||
|
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||||
|
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
|
||||||
|
0x00, // session ID length
|
||||||
|
0x13, 0x01, // cipher suite
|
||||||
|
0x00, // compression method
|
||||||
|
0x00, 0x06, // extensions length = 6
|
||||||
|
// supported_versions (server format - single value)
|
||||||
|
0x00, 0x2b, // type
|
||||||
|
0x00, 0x02, // length = 2
|
||||||
|
0x03, 0x04, // TLS 1.3
|
||||||
|
}
|
||||||
|
|
||||||
|
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
|
||||||
|
if m == nil {
|
||||||
|
t.Fatal("ParseTLSServerHelloMsgData returned nil")
|
||||||
|
}
|
||||||
|
v, ok := m["supported_versions"].(uint16)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
|
||||||
|
}
|
||||||
|
if v != 0x0304 {
|
||||||
|
t.Errorf("supported_versions = 0x%04x, want 0x0304", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
256
analyzer/tcp/fet_test.go
Normal file
256
analyzer/tcp/fet_test.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPopCount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input byte
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{0x00, 0},
|
||||||
|
{0x01, 1},
|
||||||
|
{0x02, 1},
|
||||||
|
{0x03, 2},
|
||||||
|
{0x07, 3},
|
||||||
|
{0x0f, 4},
|
||||||
|
{0xff, 8},
|
||||||
|
{0x55, 4}, // 01010101
|
||||||
|
{0xaa, 4}, // 10101010
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := popCount(tt.input); got != tt.want {
|
||||||
|
t.Errorf("popCount(0x%02x) = %d, want %d", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAveragePopCount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want float32
|
||||||
|
}{
|
||||||
|
{"empty", []byte{}, 0},
|
||||||
|
{"all zeros", []byte{0x00, 0x00, 0x00}, 0},
|
||||||
|
{"all ones", []byte{0xff, 0xff}, 8.0},
|
||||||
|
{"mixed", []byte{0x00, 0xff}, 4.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := averagePopCount(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("averagePopCount() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFirstSixPrintable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"too short", []byte("abc"), false},
|
||||||
|
{"all printable", []byte("abcdef"), true},
|
||||||
|
{"non-printable at pos 0", []byte{0x00, 'a', 'b', 'c', 'd', 'e'}, false},
|
||||||
|
{"non-printable at pos 5", []byte{'a', 'b', 'c', 'd', 'e', 0x1f}, false},
|
||||||
|
{"exactly 6 printable", []byte("123456"), true},
|
||||||
|
{"spaces", []byte(" "), true},
|
||||||
|
{"non-printable middle", []byte{'a', 'b', 0x01, 'd', 'e', 'f'}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isFirstSixPrintable(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isFirstSixPrintable(%q) = %v, want %v", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrintablePercentage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want float32
|
||||||
|
}{
|
||||||
|
{"empty", []byte{}, 0},
|
||||||
|
{"all printable", []byte("hello"), 1.0},
|
||||||
|
{"none printable", []byte{0x00, 0x01, 0x02, 0x03}, 0},
|
||||||
|
{"half printable", []byte{'a', 0x00, 'b', 0x00}, 0.5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := printablePercentage(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("printablePercentage() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContiguousPrintable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"empty", []byte{}, 0},
|
||||||
|
{"all printable", []byte("hello world"), 11},
|
||||||
|
{"none printable", []byte{0x00, 0x01, 0x02}, 0},
|
||||||
|
{"start printable", []byte{'a', 'b', 'c', 0x00, 'd', 'e', 'f'}, 3},
|
||||||
|
{"end printable", []byte{0x00, 'a', 'b', 'c', 'd', 'e', 'f'}, 6},
|
||||||
|
{"middle printable", []byte{0x00, 'a', 'b', 'c', 0x00}, 3},
|
||||||
|
{"two segments", []byte{'a', 'b', 0x00, 'c', 'd', 'e'}, 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := contiguousPrintable(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("contiguousPrintable() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTLSorHTTP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"too short", []byte("AB"), false},
|
||||||
|
// TLS ClientHello: 0x16 0x03 0x01
|
||||||
|
{"tls 1.0", []byte{0x16, 0x03, 0x01}, true},
|
||||||
|
// TLS 0x17 is application data record
|
||||||
|
{"tls app data", []byte{0x17, 0x03, 0x03}, true},
|
||||||
|
{"tls max content type", []byte{0x17, 0x03, 0x09}, true},
|
||||||
|
{"bad tls content type", []byte{0x15, 0x03, 0x01}, false},
|
||||||
|
{"bad tls version", []byte{0x16, 0x04, 0x01}, false},
|
||||||
|
{"bad tls length", []byte{0x16, 0x03, 0x0a}, false},
|
||||||
|
// HTTP methods
|
||||||
|
{"GET", []byte("GET / HTTP/1.1..."), true},
|
||||||
|
{"HEAD", []byte("HEAD /index.html..."), true},
|
||||||
|
{"POST", []byte("POST /api..."), true},
|
||||||
|
{"PUT", []byte("PUT /data..."), true},
|
||||||
|
{"DELETE", []byte("DELETE /..."), true},
|
||||||
|
{"CONNECT", []byte("CONNECT proxy..."), true},
|
||||||
|
{"OPTIONS", []byte("OPTIONS *..."), true},
|
||||||
|
{"TRACE", []byte("TRACE /..."), true},
|
||||||
|
{"PATCH", []byte("PATCH /..."), true},
|
||||||
|
{"random data", []byte{0x00, 0x01, 0x02, 0x03}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isTLSorHTTP(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isTLSorHTTP(%v) = %v, want %v", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPrintable(t *testing.T) {
|
||||||
|
if !isPrintable('a') {
|
||||||
|
t.Error("'a' should be printable")
|
||||||
|
}
|
||||||
|
if isPrintable(0x00) {
|
||||||
|
t.Error("0x00 should not be printable")
|
||||||
|
}
|
||||||
|
if !isPrintable(0x20) {
|
||||||
|
t.Error("0x20 (space) should be printable")
|
||||||
|
}
|
||||||
|
if !isPrintable(0x7e) {
|
||||||
|
t.Error("0x7e (~) should be printable")
|
||||||
|
}
|
||||||
|
if isPrintable(0x7f) {
|
||||||
|
t.Error("0x7f (DEL) should not be printable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETStream_Feed(t *testing.T) {
|
||||||
|
s := newFETStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("FET should be done after first packet")
|
||||||
|
}
|
||||||
|
m := u.M
|
||||||
|
if _, ok := m["ex1"]; !ok {
|
||||||
|
t.Error("ex1 missing")
|
||||||
|
}
|
||||||
|
if _, ok := m["ex2"]; !ok {
|
||||||
|
t.Error("ex2 missing")
|
||||||
|
}
|
||||||
|
if _, ok := m["ex3"]; !ok {
|
||||||
|
t.Error("ex3 missing")
|
||||||
|
}
|
||||||
|
if _, ok := m["ex4"]; !ok {
|
||||||
|
t.Error("ex4 missing")
|
||||||
|
}
|
||||||
|
if _, ok := m["ex5"]; !ok {
|
||||||
|
t.Error("ex5 missing")
|
||||||
|
}
|
||||||
|
if yes, ok := m["yes"].(bool); ok && yes {
|
||||||
|
t.Error("HTTP should be exempt (yes=false)")
|
||||||
|
}
|
||||||
|
if u.Type != 2 {
|
||||||
|
t.Errorf("prop update type = %d, want PropUpdateReplace", u.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETStream_Feed_EncryptedLike(t *testing.T) {
|
||||||
|
s := newFETStream(nil)
|
||||||
|
data := make([]byte, 100)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
u, done := s.Feed(false, false, false, 0, data)
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("should be done")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETStream_Feed_Skip(t *testing.T) {
|
||||||
|
s := newFETStream(nil)
|
||||||
|
_, done := s.Feed(false, false, false, 5, []byte("data"))
|
||||||
|
if !done {
|
||||||
|
t.Error("skip != 0 should return done=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETStream_Feed_Empty(t *testing.T) {
|
||||||
|
s := newFETStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte{})
|
||||||
|
if u != nil || done {
|
||||||
|
t.Error("empty data should return nil, false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETAnalyzer_Name(t *testing.T) {
|
||||||
|
a := &FETAnalyzer{}
|
||||||
|
if a.Name() != "fet" {
|
||||||
|
t.Errorf("Name() = %q, want fet", a.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFETStream_Close(t *testing.T) {
|
||||||
|
s := newFETStream(nil)
|
||||||
|
if u := s.Close(false); u != nil {
|
||||||
|
t.Error("Close should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
159
analyzer/tcp/ssh_test.go
Normal file
159
analyzer/tcp/ssh_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSHAnalyzer_Name(t *testing.T) {
|
||||||
|
a := &SSHAnalyzer{}
|
||||||
|
if a.Name() != "ssh" {
|
||||||
|
t.Errorf("Name() = %q, want ssh", a.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_Client(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3\r\n"))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("should not be done (server not yet received)")
|
||||||
|
}
|
||||||
|
client, ok := u.M["client"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("client prop missing")
|
||||||
|
}
|
||||||
|
if client["protocol"] != "2.0" {
|
||||||
|
t.Errorf("protocol = %v, want 2.0", client["protocol"])
|
||||||
|
}
|
||||||
|
if client["software"] != "OpenSSH_8.9p1" {
|
||||||
|
t.Errorf("software = %v, want OpenSSH_8.9p1", client["software"])
|
||||||
|
}
|
||||||
|
if comments, ok := client["comments"]; ok {
|
||||||
|
if comments != "Ubuntu-3" {
|
||||||
|
t.Errorf("comments = %v, want Ubuntu-3", comments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_ClientWithComments(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_7.4 Ubuntu-3\r\n"))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("should not be done")
|
||||||
|
}
|
||||||
|
client := u.M["client"].(analyzer.PropMap)
|
||||||
|
if client["comments"] != "Ubuntu-3" {
|
||||||
|
t.Errorf("comments = %v, want Ubuntu-3", client["comments"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_Both(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9\r\n"))
|
||||||
|
u, done := s.Feed(true, false, false, 0, []byte("SSH-2.0-dropbear_2022.83\r\n"))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("should be done after both sides")
|
||||||
|
}
|
||||||
|
server, ok := u.M["server"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("server prop missing")
|
||||||
|
}
|
||||||
|
if server["software"] != "dropbear_2022.83" {
|
||||||
|
t.Errorf("server software = %v", server["software"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_NotSSH(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("HTTP/1.1 200 OK\r\n"))
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil for non-SSH")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("should be cancelled (done) for non-SSH")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_InvalidLine(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-foo bar baz\r\n"))
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil for invalid line (>2 fields)")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("should be cancelled for invalid SSH line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_NoLineEnd(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH"))
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil when no EOL found yet")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("should not be done, waiting for more data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_IncompleteThenComplete(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, _ := s.Feed(false, false, false, 0, []byte("SSH-2.0-"))
|
||||||
|
if u != nil {
|
||||||
|
t.Error("first partial feed should return nil")
|
||||||
|
}
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte("Dropbear\r\n"))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("second feed should return update")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("should not be done (only client)")
|
||||||
|
}
|
||||||
|
client := u.M["client"].(analyzer.PropMap)
|
||||||
|
if client["software"] != "Dropbear" {
|
||||||
|
t.Errorf("software = %v", client["software"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_Skip(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
_, done := s.Feed(false, false, false, 5, []byte("data"))
|
||||||
|
if !done {
|
||||||
|
t.Error("skip != 0 should return done=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Feed_Empty(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
u, done := s.Feed(false, false, false, 0, []byte{})
|
||||||
|
if u != nil || done {
|
||||||
|
t.Error("empty data should return nil, false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHStream_Close(t *testing.T) {
|
||||||
|
s := newSSHStream(nil)
|
||||||
|
s.clientBuf.Append([]byte("data"))
|
||||||
|
s.serverBuf.Append([]byte("data"))
|
||||||
|
s.clientMap = analyzer.PropMap{"key": "val"}
|
||||||
|
u := s.Close(false)
|
||||||
|
if u != nil {
|
||||||
|
t.Error("Close should return nil")
|
||||||
|
}
|
||||||
|
if s.clientBuf.Len() != 0 || s.serverBuf.Len() != 0 {
|
||||||
|
t.Error("Close should reset buffers")
|
||||||
|
}
|
||||||
|
if s.clientMap != nil || s.serverMap != nil {
|
||||||
|
t.Error("Close should nil maps")
|
||||||
|
}
|
||||||
|
}
|
||||||
187
analyzer/udp/wireguard_test.go
Normal file
187
analyzer/udp/wireguard_test.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeWireGuardPacket(msgType byte, body []byte) []byte {
|
||||||
|
buf := make([]byte, 1+3+len(body))
|
||||||
|
buf[0] = msgType
|
||||||
|
copy(buf[4:], body)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_HandshakeInitiation(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
body := make([]byte, wireguardSizeHandshakeInitiation-4)
|
||||||
|
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
|
||||||
|
|
||||||
|
u, done := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("should not be done")
|
||||||
|
}
|
||||||
|
if u.Type != analyzer.PropUpdateReplace {
|
||||||
|
t.Errorf("Type = %d, want PropUpdateReplace", u.Type)
|
||||||
|
}
|
||||||
|
initMap, ok := u.M["handshake_initiation"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("handshake_initiation missing")
|
||||||
|
}
|
||||||
|
if idx := initMap["sender_index"].(uint32); idx != 0x01020304 {
|
||||||
|
t.Errorf("sender_index = %d, want %d", idx, 0x01020304)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_HandshakeResponse(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
body := make([]byte, wireguardSizeHandshakeResponse-4)
|
||||||
|
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
|
||||||
|
binary.LittleEndian.PutUint32(body[4:8], 0x05060708)
|
||||||
|
|
||||||
|
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeResponse, body))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
respMap, ok := u.M["handshake_response"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("handshake_response missing")
|
||||||
|
}
|
||||||
|
if idx := respMap["sender_index"].(uint32); idx != 0x01020304 {
|
||||||
|
t.Errorf("sender_index = %d", idx)
|
||||||
|
}
|
||||||
|
if idx := respMap["receiver_index"].(uint32); idx != 0x05060708 {
|
||||||
|
t.Errorf("receiver_index = %d", idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_PacketData(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
body := make([]byte, wireguardMinSizePacketData-4)
|
||||||
|
binary.LittleEndian.PutUint32(body[0:4], 0x0a0b0c0d)
|
||||||
|
binary.LittleEndian.PutUint64(body[4:12], 42)
|
||||||
|
|
||||||
|
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeData, body))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
dataMap, ok := u.M["packet_data"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("packet_data missing")
|
||||||
|
}
|
||||||
|
if idx := dataMap["receiver_index"].(uint32); idx != 0x0a0b0c0d {
|
||||||
|
t.Errorf("receiver_index = %d", idx)
|
||||||
|
}
|
||||||
|
if ctr := dataMap["counter"].(uint64); ctr != 42 {
|
||||||
|
t.Errorf("counter = %d", ctr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_CookieReply(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
body := make([]byte, wireguardSizePacketCookieReply-4)
|
||||||
|
binary.LittleEndian.PutUint32(body[0:4], 0x11111111)
|
||||||
|
|
||||||
|
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeCookieReply, body))
|
||||||
|
if u == nil {
|
||||||
|
t.Fatal("Feed returned nil update")
|
||||||
|
}
|
||||||
|
crMap, ok := u.M["packet_cookie_reply"].(analyzer.PropMap)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("packet_cookie_reply missing")
|
||||||
|
}
|
||||||
|
if idx := crMap["receiver_index"].(uint32); idx != 0x11111111 {
|
||||||
|
t.Errorf("receiver_index = %d", idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_InvalidLength(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
u, done := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||||
|
if u == nil || !done {
|
||||||
|
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||||
|
}
|
||||||
|
if u == nil || !done {
|
||||||
|
u, _ = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||||
|
}
|
||||||
|
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
|
||||||
|
if u != nil || !done {
|
||||||
|
t.Error("should return done after 4 invalid packets")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_TooShort(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
u, _ := s.Feed(false, []byte{1, 0, 0})
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil for too-short packet")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_NonZeroReserved(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
u, _ := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 1, 0, 0, 0, 0})
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil when reserved bytes are non-zero")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Feed_HandshakeInitiation_WrongSize(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
body := make([]byte, wireguardSizeHandshakeInitiation-4+1)
|
||||||
|
binary.LittleEndian.PutUint32(body[0:4], 1)
|
||||||
|
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
|
||||||
|
if u != nil {
|
||||||
|
t.Error("should return nil for wrong init size")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_Close(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
u := s.Close(false)
|
||||||
|
if u != nil {
|
||||||
|
t.Error("Close() should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_PutSenderIndex_MatchReceiverIndex(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
|
||||||
|
s.putSenderIndex(false, 0xdeadbeef)
|
||||||
|
if !s.matchReceiverIndex(true, 0xdeadbeef) {
|
||||||
|
t.Error("should match reverse direction")
|
||||||
|
}
|
||||||
|
if s.matchReceiverIndex(false, 0xdeadbeef) {
|
||||||
|
t.Error("should not match same direction")
|
||||||
|
}
|
||||||
|
if s.matchReceiverIndex(true, 0x12345678) {
|
||||||
|
t.Error("should not match wrong index")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardUDPStream_RememberedIndexRing(t *testing.T) {
|
||||||
|
s := newWireGuardUDPStream(nil)
|
||||||
|
|
||||||
|
for i := uint32(0); i < wireguardRememberedIndexCount+2; i++ {
|
||||||
|
s.putSenderIndex(false, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.matchReceiverIndex(true, 0) {
|
||||||
|
t.Error("index 0 should have been evicted from ring")
|
||||||
|
}
|
||||||
|
if !s.matchReceiverIndex(true, uint32(wireguardRememberedIndexCount+1)) {
|
||||||
|
t.Error("latest index should still be in ring")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireGuardAnalyzer_Name(t *testing.T) {
|
||||||
|
a := &WireGuardAnalyzer{}
|
||||||
|
if a.Name() != "wireguard" {
|
||||||
|
t.Errorf("Name() = %q, want wireguard", a.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
321
analyzer/utils/bytebuffer_test.go
Normal file
321
analyzer/utils/bytebuffer_test.go
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestByteBuffer_Append(t *testing.T) {
|
||||||
|
b := &ByteBuffer{}
|
||||||
|
b.Append([]byte("hello"))
|
||||||
|
b.Append([]byte(" "))
|
||||||
|
b.Append([]byte("world"))
|
||||||
|
if string(b.Buf) != "hello world" {
|
||||||
|
t.Errorf("Append() result = %q, want %q", string(b.Buf), "hello world")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_Len(t *testing.T) {
|
||||||
|
b := &ByteBuffer{}
|
||||||
|
if b.Len() != 0 {
|
||||||
|
t.Errorf("Len() = %d, want 0", b.Len())
|
||||||
|
}
|
||||||
|
b.Append([]byte("abc"))
|
||||||
|
if b.Len() != 3 {
|
||||||
|
t.Errorf("Len() = %d, want 3", b.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_Index(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||||
|
if i := b.Index([]byte("world")); i != 6 {
|
||||||
|
t.Errorf("Index('world') = %d, want 6", i)
|
||||||
|
}
|
||||||
|
if i := b.Index([]byte("xyz")); i != -1 {
|
||||||
|
t.Errorf("Index('xyz') = %d, want -1", i)
|
||||||
|
}
|
||||||
|
if i := b.Index([]byte{}); i != 0 {
|
||||||
|
t.Errorf("Index('') = %d, want 0", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_Get(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("abcdef")}
|
||||||
|
|
||||||
|
data, ok := b.Get(3, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get(3, true) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("abc")) {
|
||||||
|
t.Errorf("Get(3, true) data = %q, want %q", data, "abc")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b.Buf, []byte("def")) {
|
||||||
|
t.Errorf("after consume, Buf = %q, want %q", b.Buf, "def")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok = b.Get(4, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("Get(4, false) should return false (only 3 bytes left)")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok = b.Get(2, false)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get(2, false) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("de")) {
|
||||||
|
t.Errorf("Get(2, false) data = %q, want %q", data, "de")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b.Buf, []byte("def")) {
|
||||||
|
t.Errorf("after non-consume, Buf = %q, want %q (unchanged)", b.Buf, "def")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok = b.Get(3, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get(3, true) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("def")) {
|
||||||
|
t.Errorf("Get(3, true) data = %q, want %q", data, "def")
|
||||||
|
}
|
||||||
|
if b.Len() != 0 {
|
||||||
|
t.Errorf("after consume all, Len() = %d, want 0", b.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetString(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||||
|
|
||||||
|
s, ok := b.GetString(5, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetString(5, true) returned false")
|
||||||
|
}
|
||||||
|
if s != "hello" {
|
||||||
|
t.Errorf("GetString() = %q, want %q", s, "hello")
|
||||||
|
}
|
||||||
|
if b.Len() != 6 {
|
||||||
|
t.Errorf("after consume, Len() = %d, want 6", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok = b.GetString(7, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetString(7, false) should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok = b.GetString(6, false)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetString(6, false) returned false")
|
||||||
|
}
|
||||||
|
if s != " world" {
|
||||||
|
t.Errorf("GetString() = %q, want %q", s, " world")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetByte(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("abc")}
|
||||||
|
|
||||||
|
bt, ok := b.GetByte(true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetByte(true) returned false")
|
||||||
|
}
|
||||||
|
if bt != 'a' {
|
||||||
|
t.Errorf("GetByte() = %c, want 'a'", bt)
|
||||||
|
}
|
||||||
|
if b.Len() != 2 {
|
||||||
|
t.Errorf("after consume, Len() = %d, want 2", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
bt, ok = b.GetByte(false)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetByte(false) returned false")
|
||||||
|
}
|
||||||
|
if bt != 'b' {
|
||||||
|
t.Errorf("GetByte() = %c, want 'b'", bt)
|
||||||
|
}
|
||||||
|
if b.Len() != 2 {
|
||||||
|
t.Errorf("after non-consume, Len() = %d, want 2 (unchanged)", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
b.GetByte(true)
|
||||||
|
b.GetByte(true)
|
||||||
|
bt, ok = b.GetByte(true)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetByte(true) on empty buffer should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetUint16(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04}}
|
||||||
|
|
||||||
|
v, ok := b.GetUint16(false, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUint16(bigEndian) returned false")
|
||||||
|
}
|
||||||
|
if v != 0x0102 {
|
||||||
|
t.Errorf("GetUint16(bigEndian) = 0x%04x, want 0x0102", v)
|
||||||
|
}
|
||||||
|
if b.Len() != 2 {
|
||||||
|
t.Errorf("after consume, Len() = %d, want 2", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = b.GetUint16(true, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUint16(littleEndian) returned false")
|
||||||
|
}
|
||||||
|
if v != 0x0403 {
|
||||||
|
t.Errorf("GetUint16(littleEndian) = 0x%04x, want 0x0403", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = b.GetUint16(false, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetUint16 on empty buffer should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetUint32(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}
|
||||||
|
|
||||||
|
v, ok := b.GetUint32(false, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUint32(bigEndian) returned false")
|
||||||
|
}
|
||||||
|
if v != 0x01020304 {
|
||||||
|
t.Errorf("GetUint32(bigEndian) = 0x%08x, want 0x01020304", v)
|
||||||
|
}
|
||||||
|
if b.Len() != 4 {
|
||||||
|
t.Errorf("after consume, Len() = %d, want 4", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = b.GetUint32(true, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUint32(littleEndian) returned false")
|
||||||
|
}
|
||||||
|
if v != 0x08070605 {
|
||||||
|
t.Errorf("GetUint32(littleEndian) = 0x%08x, want 0x08070605", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = b.GetUint32(false, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetUint32 on empty buffer should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetUntil(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("hello\r\nworld\r\n")}
|
||||||
|
|
||||||
|
data, ok := b.GetUntil([]byte("\r\n"), true, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUntil(sep, include) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("hello\r\n")) {
|
||||||
|
t.Errorf("GetUntil(include) = %q, want %q", data, "hello\\r\\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok = b.GetUntil([]byte("\r\n"), false, false)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUntil(sep, exclude, non-consume) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("world")) {
|
||||||
|
t.Errorf("GetUntil(exclude) = %q, want %q", data, "world")
|
||||||
|
}
|
||||||
|
if b.Len() != 7 {
|
||||||
|
t.Errorf("after non-consume, Len() = %d, want 7", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok = b.GetUntil([]byte("\r\n"), true, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetUntil second (consume) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("world\r\n")) {
|
||||||
|
t.Errorf("GetUntil second = %q, want %q", data, "world\\r\\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = b.GetUntil([]byte("xyz"), false, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetUntil(not found) should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetSubBuffer(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("hello world")}
|
||||||
|
|
||||||
|
sub, ok := b.GetSubBuffer(5, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("GetSubBuffer() returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(sub.Buf, []byte("hello")) {
|
||||||
|
t.Errorf("GetSubBuffer() = %q, want %q", sub.Buf, "hello")
|
||||||
|
}
|
||||||
|
if b.Len() != 6 {
|
||||||
|
t.Errorf("after consume, Len() = %d, want 6", b.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = b.GetSubBuffer(7, false)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("GetSubBuffer(7) should return false (only 6 bytes left)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_Skip(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("abcdef")}
|
||||||
|
|
||||||
|
ok := b.Skip(2)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Skip(2) returned false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b.Buf, []byte("cdef")) {
|
||||||
|
t.Errorf("after Skip(2), Buf = %q, want %q", b.Buf, "cdef")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = b.Skip(10)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("Skip(10) should return false")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b.Buf, []byte("cdef")) {
|
||||||
|
t.Errorf("after failed Skip, Buf = %q, want %q (unchanged)", b.Buf, "cdef")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = b.Skip(4)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Skip(4) returned false")
|
||||||
|
}
|
||||||
|
if b.Len() != 0 {
|
||||||
|
t.Errorf("after Skip all, Len() = %d, want 0", b.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_Reset(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("data")}
|
||||||
|
b.Reset()
|
||||||
|
if b.Buf != nil {
|
||||||
|
t.Errorf("after Reset, Buf = %v, want nil", b.Buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetZeroLength(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("abc")}
|
||||||
|
|
||||||
|
data, ok := b.Get(0, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get(0) returned false")
|
||||||
|
}
|
||||||
|
if len(data) != 0 {
|
||||||
|
t.Errorf("Get(0) len = %d, want 0", len(data))
|
||||||
|
}
|
||||||
|
if b.Len() != 3 {
|
||||||
|
t.Errorf("after Get(0, consume), Len() = %d, want 3 (0-length consume is no-op)", b.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestByteBuffer_GetConsumeDoesNotMutateReturnedSlice(t *testing.T) {
|
||||||
|
b := &ByteBuffer{Buf: []byte("hello")}
|
||||||
|
data, ok := b.Get(5, true)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get() returned false")
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(data, []byte("hello")) {
|
||||||
|
t.Errorf("Get() returned wrong data: %v", data)
|
||||||
|
}
|
||||||
|
if b.Len() != 0 {
|
||||||
|
t.Errorf("after consume, Len() should be 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
185
analyzer/utils/lsm_test.go
Normal file
185
analyzer/utils/lsm_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunPause(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction {
|
||||||
|
callCount++
|
||||||
|
return LSMActionPause
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("unexpected done=true")
|
||||||
|
}
|
||||||
|
if callCount != 1 {
|
||||||
|
t.Errorf("callCount = %d, want 1", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunNext(t *testing.T) {
|
||||||
|
callOrder := []int{}
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction { callOrder = append(callOrder, 1); return LSMActionNext },
|
||||||
|
func() LSMAction { callOrder = append(callOrder, 2); return LSMActionNext },
|
||||||
|
func() LSMAction { callOrder = append(callOrder, 3); return LSMActionNext },
|
||||||
|
)
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false")
|
||||||
|
}
|
||||||
|
if len(callOrder) != 3 {
|
||||||
|
t.Fatalf("callOrder len = %d, want 3", len(callOrder))
|
||||||
|
}
|
||||||
|
for i, v := range []int{1, 2, 3} {
|
||||||
|
if callOrder[i] != v {
|
||||||
|
t.Errorf("callOrder[%d] = %d, want %d", i, callOrder[i], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunReset(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction {
|
||||||
|
callCount++
|
||||||
|
if callCount == 1 {
|
||||||
|
return LSMActionReset
|
||||||
|
}
|
||||||
|
return LSMActionNext
|
||||||
|
},
|
||||||
|
func() LSMAction { callCount++; return LSMActionNext },
|
||||||
|
)
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false")
|
||||||
|
}
|
||||||
|
if callCount != 3 {
|
||||||
|
t.Errorf("callCount = %d, want 3 (step0 reset, step0 next, step1 next)", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunCancel(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction { callCount++; return LSMActionNext },
|
||||||
|
func() LSMAction { callCount++; return LSMActionCancel },
|
||||||
|
func() LSMAction { callCount++; return LSMActionNext },
|
||||||
|
)
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if !cancelled {
|
||||||
|
t.Error("unexpected cancelled=false")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false")
|
||||||
|
}
|
||||||
|
if callCount != 2 {
|
||||||
|
t.Errorf("callCount = %d, want 2 (third step should not execute)", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunMixed(t *testing.T) {
|
||||||
|
pauseCount := 0
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction { return LSMActionNext },
|
||||||
|
func() LSMAction {
|
||||||
|
pauseCount++
|
||||||
|
if pauseCount == 1 {
|
||||||
|
return LSMActionPause
|
||||||
|
}
|
||||||
|
return LSMActionNext
|
||||||
|
},
|
||||||
|
func() LSMAction { return LSMActionNext },
|
||||||
|
func() LSMAction { return LSMActionNext },
|
||||||
|
)
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true")
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
t.Error("unexpected done=true on first run")
|
||||||
|
}
|
||||||
|
cancelled, done = lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true on second run")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false on second run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_RunEmpty(t *testing.T) {
|
||||||
|
lsm := NewLinearStateMachine()
|
||||||
|
cancelled, done := lsm.Run()
|
||||||
|
if cancelled {
|
||||||
|
t.Error("unexpected cancelled=true")
|
||||||
|
}
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false for empty LSM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_AppendSteps(t *testing.T) {
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction { return LSMActionNext },
|
||||||
|
)
|
||||||
|
lsm.Run()
|
||||||
|
lsm.AppendSteps(
|
||||||
|
func() LSMAction { return LSMActionNext },
|
||||||
|
)
|
||||||
|
_, done := lsm.Run()
|
||||||
|
if !done {
|
||||||
|
t.Error("unexpected done=false after AppendSteps")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLinearStateMachine_Reset(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
lsm := NewLinearStateMachine(
|
||||||
|
func() LSMAction { callCount++; return LSMActionCancel },
|
||||||
|
)
|
||||||
|
lsm.Run()
|
||||||
|
if !lsm.cancelled {
|
||||||
|
t.Error("expected cancelled=true after cancel")
|
||||||
|
}
|
||||||
|
lsm.Reset()
|
||||||
|
if lsm.cancelled {
|
||||||
|
t.Error("expected cancelled=false after Reset")
|
||||||
|
}
|
||||||
|
if lsm.index != 0 {
|
||||||
|
t.Errorf("expected index=0 after Reset, got %d", lsm.index)
|
||||||
|
}
|
||||||
|
_, done := lsm.Run()
|
||||||
|
if !done {
|
||||||
|
t.Error("expected done=true, step executed again after Reset")
|
||||||
|
}
|
||||||
|
if callCount != 2 {
|
||||||
|
t.Errorf("callCount = %d, want 2 (first run + reset run)", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLSMActionConstants(t *testing.T) {
|
||||||
|
if LSMActionPause != 0 {
|
||||||
|
t.Errorf("LSMActionPause = %d, want 0", LSMActionPause)
|
||||||
|
}
|
||||||
|
if LSMActionNext != 1 {
|
||||||
|
t.Errorf("LSMActionNext = %d, want 1", LSMActionNext)
|
||||||
|
}
|
||||||
|
if LSMActionReset != 2 {
|
||||||
|
t.Errorf("LSMActionReset = %d, want 2", LSMActionReset)
|
||||||
|
}
|
||||||
|
if LSMActionCancel != 3 {
|
||||||
|
t.Errorf("LSMActionCancel = %d, want 3", LSMActionCancel)
|
||||||
|
}
|
||||||
|
}
|
||||||
29
analyzer/utils/string_test.go
Normal file
29
analyzer/utils/string_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestByteSlicesToStrings(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input [][]byte
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"nil", nil, []string{}},
|
||||||
|
{"empty", [][]byte{}, []string{}},
|
||||||
|
{"single", [][]byte{[]byte("hello")}, []string{"hello"}},
|
||||||
|
{"multiple", [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, []string{"foo", "bar", "baz"}},
|
||||||
|
{"empty element", [][]byte{[]byte("a"), []byte{}, []byte("b")}, []string{"a", "", "b"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ByteSlicesToStrings(tt.input)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("ByteSlicesToStrings() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
45
errors_test.go
Normal file
45
errors_test.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package mellaris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigError_Error(t *testing.T) {
|
||||||
|
e := ConfigError{Field: "port", Err: errors.New("must be > 0")}
|
||||||
|
want := "invalid config: port: must be > 0"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("Error() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigError_Error_NilErr(t *testing.T) {
|
||||||
|
e := ConfigError{Field: "host", Err: nil}
|
||||||
|
want := "invalid config: host: %!s(<nil>)"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("Error() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigError_Unwrap(t *testing.T) {
|
||||||
|
wrapped := errors.New("inner error")
|
||||||
|
e := ConfigError{Field: "field", Err: wrapped}
|
||||||
|
if got := e.Unwrap(); got != wrapped {
|
||||||
|
t.Errorf("Unwrap() = %v, want %v", got, wrapped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigError_Unwrap_Nil(t *testing.T) {
|
||||||
|
e := ConfigError{Field: "field"}
|
||||||
|
if got := e.Unwrap(); got != nil {
|
||||||
|
t.Errorf("Unwrap() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigError_ErrorsIs(t *testing.T) {
|
||||||
|
wrapped := errors.New("inner")
|
||||||
|
e := ConfigError{Field: "x", Err: wrapped}
|
||||||
|
if !errors.Is(e, wrapped) {
|
||||||
|
t.Error("errors.Is should find wrapped error")
|
||||||
|
}
|
||||||
|
}
|
||||||
22
modifier/interface_test.go
Normal file
22
modifier/interface_test.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package modifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrInvalidPacket_Error(t *testing.T) {
|
||||||
|
e := &ErrInvalidPacket{Err: errors.New("bad checksum")}
|
||||||
|
want := "invalid packet: bad checksum"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("Error() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrInvalidArgs_Error(t *testing.T) {
|
||||||
|
e := &ErrInvalidArgs{Err: errors.New("missing 'a' arg")}
|
||||||
|
want := "invalid args: missing 'a' arg"
|
||||||
|
if got := e.Error(); got != want {
|
||||||
|
t.Errorf("Error() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
205
modifier/udp/dns_test.go
Normal file
205
modifier/udp/dns_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/modifier"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte {
|
||||||
|
t.Helper()
|
||||||
|
dns := &layers.DNS{
|
||||||
|
ID: id,
|
||||||
|
QR: true,
|
||||||
|
OpCode: layers.DNSOpCodeQuery,
|
||||||
|
AA: false,
|
||||||
|
RD: true,
|
||||||
|
RA: true,
|
||||||
|
ResponseCode: layers.DNSResponseCodeNoErr,
|
||||||
|
QDCount: 1,
|
||||||
|
Questions: []layers.DNSQuestion{
|
||||||
|
{
|
||||||
|
Name: []byte(name),
|
||||||
|
Type: qtype,
|
||||||
|
Class: layers.DNSClassIN,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
|
||||||
|
FixLengths: true,
|
||||||
|
ComputeChecksums: true,
|
||||||
|
}, dns)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to serialize DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_Name(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
if m.Name() != "dns" {
|
||||||
|
t.Errorf("Name() = %q, want dns", m.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_New_NoArgs(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
inst, err := m.New(map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if inst == nil {
|
||||||
|
t.Fatal("New() returned nil instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_New_WithA(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
inst, err := m.New(map[string]interface{}{
|
||||||
|
"a": "192.168.1.1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
dnsInst, ok := inst.(*dnsModifierInstance)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("New() returned wrong type: %T", inst)
|
||||||
|
}
|
||||||
|
if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" {
|
||||||
|
t.Errorf("A = %v, want 192.168.1.1", dnsInst.A)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_New_WithAAAA(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
inst, err := m.New(map[string]interface{}{
|
||||||
|
"aaaa": "2001:db8::1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
dnsInst, ok := inst.(*dnsModifierInstance)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("New() returned wrong type: %T", inst)
|
||||||
|
}
|
||||||
|
if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" {
|
||||||
|
t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_New_InvalidA(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
_, err := m.New(map[string]interface{}{
|
||||||
|
"a": "not-an-ip",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("New() expected error for invalid A")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifier_New_InvalidAAAA(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
_, err := m.New(map[string]interface{}{
|
||||||
|
"aaaa": "not-an-ip",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("New() expected error for invalid AAAA")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifierInstance_Process_AType(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
inst, err := m.New(map[string]interface{}{
|
||||||
|
"a": "10.10.10.10",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com")
|
||||||
|
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Process() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &layers.DNS{}
|
||||||
|
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode result: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Answers) != 1 {
|
||||||
|
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
||||||
|
}
|
||||||
|
if result.Answers[0].Type != layers.DNSTypeA {
|
||||||
|
t.Errorf("answer type = %d, want A", result.Answers[0].Type)
|
||||||
|
}
|
||||||
|
if result.Answers[0].IP.String() != "10.10.10.10" {
|
||||||
|
t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifierInstance_Process_AAAAType(t *testing.T) {
|
||||||
|
m := &DNSModifier{}
|
||||||
|
inst, err := m.New(map[string]interface{}{
|
||||||
|
"aaaa": "2001:db8::1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com")
|
||||||
|
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Process() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &layers.DNS{}
|
||||||
|
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode result: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Answers) != 1 {
|
||||||
|
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
||||||
|
}
|
||||||
|
if result.Answers[0].Type != layers.DNSTypeAAAA {
|
||||||
|
t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) {
|
||||||
|
inst := &dnsModifierInstance{}
|
||||||
|
query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com")
|
||||||
|
query[2] &^= 0x80 // clear QR bit
|
||||||
|
|
||||||
|
_, err := inst.Process(query)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-response")
|
||||||
|
}
|
||||||
|
errPkt, ok := err.(*modifier.ErrInvalidPacket)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected *ErrInvalidPacket, got %T", err)
|
||||||
|
}
|
||||||
|
_ = errPkt
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) {
|
||||||
|
inst := &dnsModifierInstance{}
|
||||||
|
dns := &layers.DNS{
|
||||||
|
ID: 1,
|
||||||
|
QR: true,
|
||||||
|
RD: true,
|
||||||
|
RA: true,
|
||||||
|
ResponseCode: layers.DNSResponseCodeNoErr,
|
||||||
|
}
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns)
|
||||||
|
|
||||||
|
_, err := inst.Process(buf.Bytes())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty questions")
|
||||||
|
}
|
||||||
|
}
|
||||||
98
ruleset/builtins/cidr_test.go
Normal file
98
ruleset/builtins/cidr_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package builtins
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompileCIDR(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr bool
|
||||||
|
wantStr string
|
||||||
|
}{
|
||||||
|
{"valid ipv4", "192.168.0.0/24", false, "192.168.0.0/24"},
|
||||||
|
{"valid ipv6", "2001:db8::/32", false, "2001:db8::/32"},
|
||||||
|
{"valid host ipv4", "10.0.0.1/32", false, "10.0.0.1/32"},
|
||||||
|
{"valid host ipv6", "::1/128", false, "::1/128"},
|
||||||
|
{"invalid no mask", "192.168.0.0", true, ""},
|
||||||
|
{"invalid bad ip", "not-an-ip/24", true, ""},
|
||||||
|
{"invalid empty", "", true, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := CompileCIDR(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("CompileCIDR(%q) expected error, got nil", tt.input)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompileCIDR(%q) unexpected error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
if got.String() != tt.wantStr {
|
||||||
|
t.Errorf("CompileCIDR(%q) = %q, want %q", tt.input, got.String(), tt.wantStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchCIDR(t *testing.T) {
|
||||||
|
cidr := mustCompileCIDR(t, "192.168.0.0/24")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"inside", "192.168.0.1", true},
|
||||||
|
{"boundary low", "192.168.0.0", true},
|
||||||
|
{"boundary high", "192.168.0.255", true},
|
||||||
|
{"outside", "192.168.1.1", false},
|
||||||
|
{"different network", "10.0.0.1", false},
|
||||||
|
{"invalid ip", "not-an-ip", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"ipv6 in ipv4", "::1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := MatchCIDR(tt.ip, cidr)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MatchCIDR(%q, %q) = %v, want %v", tt.ip, cidr, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchCIDR_IPv6(t *testing.T) {
|
||||||
|
cidr := mustCompileCIDR(t, "2001:db8::/32")
|
||||||
|
|
||||||
|
inside := "2001:db8::1"
|
||||||
|
if !MatchCIDR(inside, cidr) {
|
||||||
|
t.Errorf("MatchCIDR(%q) should be true", inside)
|
||||||
|
}
|
||||||
|
|
||||||
|
outside := "2001:db9::1"
|
||||||
|
if MatchCIDR(outside, cidr) {
|
||||||
|
t.Errorf("MatchCIDR(%q) should be false", outside)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchCIDR_NullResult(t *testing.T) {
|
||||||
|
if MatchCIDR("10.0.0.1", &net.IPNet{}) {
|
||||||
|
t.Error("MatchCIDR with empty IPNet should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCompileCIDR(t *testing.T, cidr string) *net.IPNet {
|
||||||
|
t.Helper()
|
||||||
|
_, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse CIDR %q: %v", cidr, err)
|
||||||
|
}
|
||||||
|
return ipNet
|
||||||
|
}
|
||||||
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
115
ruleset/builtins/geo/geo_matcher_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package geo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeGeoLoader struct {
|
||||||
|
geoip map[string]*v2geo.GeoIP
|
||||||
|
geosite map[string]*v2geo.GeoSite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
||||||
|
return l.geoip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
||||||
|
return l.geosite, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoIp_Cached(t *testing.T) {
|
||||||
|
loader := &fakeGeoLoader{
|
||||||
|
geoip: map[string]*v2geo.GeoIP{
|
||||||
|
"us": {
|
||||||
|
Cidr: []*v2geo.CIDR{
|
||||||
|
{Ip: ipv4(8, 8, 8, 0), Prefix: 24},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
g.geoLoader = loader
|
||||||
|
|
||||||
|
if !g.MatchGeoIp("8.8.8.8", "US") {
|
||||||
|
t.Error("MatchGeoIp should match 8.8.8.8 in US range")
|
||||||
|
}
|
||||||
|
if g.MatchGeoIp("9.9.9.9", "US") {
|
||||||
|
t.Error("MatchGeoIp should not match 9.9.9.9 in US range")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoIp_EmptyCondition(t *testing.T) {
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
if g.MatchGeoIp("1.2.3.4", "") {
|
||||||
|
t.Error("MatchGeoIp with empty condition should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoIp_InvalidIP(t *testing.T) {
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
if g.MatchGeoIp("not-an-ip", "us") {
|
||||||
|
t.Error("MatchGeoIp with invalid IP should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoIp_MissingCountry(t *testing.T) {
|
||||||
|
loader := &fakeGeoLoader{
|
||||||
|
geoip: map[string]*v2geo.GeoIP{},
|
||||||
|
}
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
g.geoLoader = loader
|
||||||
|
|
||||||
|
if g.MatchGeoIp("8.8.8.8", "us") {
|
||||||
|
t.Error("MatchGeoIp for missing country should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSite(t *testing.T) {
|
||||||
|
loader := &fakeGeoLoader{
|
||||||
|
geosite: map[string]*v2geo.GeoSite{
|
||||||
|
"openai": {
|
||||||
|
Domain: []*v2geo.Domain{
|
||||||
|
{Type: v2geo.Domain_Plain, Value: "openai"},
|
||||||
|
{Type: v2geo.Domain_Full, Value: "chatgpt.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
g.geoLoader = loader
|
||||||
|
|
||||||
|
if !g.MatchGeoSite("api.openai.com", "openai") {
|
||||||
|
t.Error("MatchGeoSite should match via plain domain")
|
||||||
|
}
|
||||||
|
if !g.MatchGeoSite("chatgpt.com", "openai") {
|
||||||
|
t.Error("MatchGeoSite should match via full domain")
|
||||||
|
}
|
||||||
|
if g.MatchGeoSite("google.com", "openai") {
|
||||||
|
t.Error("MatchGeoSite should not match unrelated host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSite_EmptyCondition(t *testing.T) {
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
if g.MatchGeoSite("test.com", "") {
|
||||||
|
t.Error("MatchGeoSite with empty condition should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoMatcher_MatchGeoSite_MissingSite(t *testing.T) {
|
||||||
|
loader := &fakeGeoLoader{
|
||||||
|
geosite: map[string]*v2geo.GeoSite{},
|
||||||
|
}
|
||||||
|
g := NewGeoMatcher("", "")
|
||||||
|
g.geoLoader = loader
|
||||||
|
|
||||||
|
if g.MatchGeoSite("test.com", "nonexistent") {
|
||||||
|
t.Error("MatchGeoSite for missing site should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipv4(a, b, c, d byte) []byte {
|
||||||
|
return []byte{a, b, c, d}
|
||||||
|
}
|
||||||
BIN
ruleset/builtins/geo/geoip.dat
Normal file
BIN
ruleset/builtins/geo/geoip.dat
Normal file
Binary file not shown.
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
324
ruleset/builtins/geo/matchers_v2geo_test.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package geo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseGeoSiteName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
wantBase string
|
||||||
|
wantAttrs []string
|
||||||
|
}{
|
||||||
|
{"google", "google", nil},
|
||||||
|
{"google@ads", "google", []string{"ads"}},
|
||||||
|
{"google@ads@news", "google", []string{"ads", "news"}},
|
||||||
|
{" google ", "google", nil},
|
||||||
|
{" google @ ads ", "google", []string{"ads"}},
|
||||||
|
{"openai@ ads @ news ", "openai", []string{"ads", "news"}},
|
||||||
|
{"@onlyattrs", "", []string{"onlyattrs"}},
|
||||||
|
{"", "", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
base, attrs := parseGeoSiteName(tt.input)
|
||||||
|
if base != tt.wantBase {
|
||||||
|
t.Errorf("parseGeoSiteName(%q) base = %q, want %q", tt.input, base, tt.wantBase)
|
||||||
|
}
|
||||||
|
if len(attrs) != len(tt.wantAttrs) {
|
||||||
|
t.Fatalf("parseGeoSiteName(%q) attrs len = %d, want %d", tt.input, len(attrs), len(tt.wantAttrs))
|
||||||
|
}
|
||||||
|
for i, attr := range attrs {
|
||||||
|
if attr != tt.wantAttrs[i] {
|
||||||
|
t.Errorf("parseGeoSiteName(%q) attrs[%d] = %q, want %q", tt.input, i, attr, tt.wantAttrs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostInfo_String(t *testing.T) {
|
||||||
|
h := HostInfo{
|
||||||
|
Name: "example.com",
|
||||||
|
IPv4: net.ParseIP("1.2.3.4"),
|
||||||
|
IPv6: net.ParseIP("::1"),
|
||||||
|
}
|
||||||
|
want := "example.com|1.2.3.4|::1"
|
||||||
|
if got := h.String(); got != want {
|
||||||
|
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostInfo_String_Partial(t *testing.T) {
|
||||||
|
h := HostInfo{
|
||||||
|
Name: "test.com",
|
||||||
|
IPv4: net.ParseIP("10.0.0.1"),
|
||||||
|
}
|
||||||
|
want := "test.com|10.0.0.1|<nil>"
|
||||||
|
if got := h.String(); got != want {
|
||||||
|
t.Errorf("HostInfo.String() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoipMatcher_Match(t *testing.T) {
|
||||||
|
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||||
|
_, n4_2, _ := net.ParseCIDR("192.168.0.0/16")
|
||||||
|
m := &geoipMatcher{
|
||||||
|
N4: []*net.IPNet{n4, n4_2},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
host HostInfo
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"ipv4 match", HostInfo{IPv4: net.ParseIP("10.1.2.3")}, true},
|
||||||
|
{"ipv4 no match", HostInfo{IPv4: net.ParseIP("172.16.0.1")}, false},
|
||||||
|
{"ipv4 match second net", HostInfo{IPv4: net.ParseIP("192.168.1.1")}, true},
|
||||||
|
{"no ip", HostInfo{}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := m.Match(tt.host); got != tt.want {
|
||||||
|
t.Errorf("Match() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoipMatcher_Match_Inverse(t *testing.T) {
|
||||||
|
_, n4, _ := net.ParseCIDR("10.0.0.0/8")
|
||||||
|
m := &geoipMatcher{
|
||||||
|
N4: []*net.IPNet{n4},
|
||||||
|
Inverse: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Match(HostInfo{IPv4: net.ParseIP("10.1.2.3")}) {
|
||||||
|
t.Error("Inverse: inside range should return false")
|
||||||
|
}
|
||||||
|
if !m.Match(HostInfo{IPv4: net.ParseIP("172.16.0.1")}) {
|
||||||
|
t.Error("Inverse: outside range should return true")
|
||||||
|
}
|
||||||
|
if !m.Match(HostInfo{}) {
|
||||||
|
t.Error("Inverse: no IP should return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeoipMatcher_Match_IPv6(t *testing.T) {
|
||||||
|
_, n6, _ := net.ParseCIDR("2001:db8::/32")
|
||||||
|
m := &geoipMatcher{
|
||||||
|
N6: []*net.IPNet{n6},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.Match(HostInfo{IPv6: net.ParseIP("2001:db8::1")}) {
|
||||||
|
t.Error("IPv6 match failed")
|
||||||
|
}
|
||||||
|
if m.Match(HostInfo{IPv6: net.ParseIP("2001:db9::1")}) {
|
||||||
|
t.Error("IPv6 should not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeositeMatcher_matchDomain_Plain(t *testing.T) {
|
||||||
|
m := &geositeMatcher{}
|
||||||
|
d := geositeDomain{
|
||||||
|
Type: geositeDomainPlain,
|
||||||
|
Value: "openai",
|
||||||
|
}
|
||||||
|
if !m.matchDomain(d, HostInfo{Name: "api.openai.com"}) {
|
||||||
|
t.Error("plain domain should match via substring")
|
||||||
|
}
|
||||||
|
if m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||||
|
t.Error("plain domain should not match unrelated host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeositeMatcher_matchDomain_Full(t *testing.T) {
|
||||||
|
m := &geositeMatcher{}
|
||||||
|
d := geositeDomain{
|
||||||
|
Type: geositeDomainFull,
|
||||||
|
Value: "example.com",
|
||||||
|
}
|
||||||
|
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||||
|
t.Error("full domain should match exact")
|
||||||
|
}
|
||||||
|
if m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||||
|
t.Error("full domain should not match subdomain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeositeMatcher_matchDomain_Root(t *testing.T) {
|
||||||
|
m := &geositeMatcher{}
|
||||||
|
d := geositeDomain{
|
||||||
|
Type: geositeDomainRoot,
|
||||||
|
Value: "example.com",
|
||||||
|
}
|
||||||
|
if !m.matchDomain(d, HostInfo{Name: "example.com"}) {
|
||||||
|
t.Error("root domain should match exact")
|
||||||
|
}
|
||||||
|
if !m.matchDomain(d, HostInfo{Name: "www.example.com"}) {
|
||||||
|
t.Error("root domain should match subdomain")
|
||||||
|
}
|
||||||
|
if m.matchDomain(d, HostInfo{Name: "www.example.com.au"}) {
|
||||||
|
t.Error("root domain should not match unrelated suffix")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeositeMatcher_matchDomain_Attrs(t *testing.T) {
|
||||||
|
m := &geositeMatcher{Attrs: []string{"ads"}}
|
||||||
|
d := geositeDomain{
|
||||||
|
Type: geositeDomainPlain,
|
||||||
|
Value: "google",
|
||||||
|
Attrs: map[string]bool{"ads": true},
|
||||||
|
}
|
||||||
|
if !m.matchDomain(d, HostInfo{Name: "google.com"}) {
|
||||||
|
t.Error("should match when domain has required attr")
|
||||||
|
}
|
||||||
|
|
||||||
|
dNoAttrs := geositeDomain{
|
||||||
|
Type: geositeDomainPlain,
|
||||||
|
Value: "google",
|
||||||
|
Attrs: map[string]bool{},
|
||||||
|
}
|
||||||
|
if m.matchDomain(dNoAttrs, HostInfo{Name: "google.com"}) {
|
||||||
|
t.Error("should not match when domain lacks required attr")
|
||||||
|
}
|
||||||
|
|
||||||
|
dOtherAttrs := geositeDomain{
|
||||||
|
Type: geositeDomainPlain,
|
||||||
|
Value: "google",
|
||||||
|
Attrs: map[string]bool{"news": true},
|
||||||
|
}
|
||||||
|
if m.matchDomain(dOtherAttrs, HostInfo{Name: "google.com"}) {
|
||||||
|
t.Error("should not match when domain has wrong attr")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeositeMatcher_Match(t *testing.T) {
|
||||||
|
m := &geositeMatcher{
|
||||||
|
Domains: []geositeDomain{
|
||||||
|
{Type: geositeDomainFull, Value: "exact.com"},
|
||||||
|
{Type: geositeDomainPlain, Value: "partial"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if !m.Match(HostInfo{Name: "exact.com"}) {
|
||||||
|
t.Error("should match full domain")
|
||||||
|
}
|
||||||
|
if !m.Match(HostInfo{Name: "www.partial.net"}) {
|
||||||
|
t.Error("should match partial domain")
|
||||||
|
}
|
||||||
|
if m.Match(HostInfo{Name: "other.net"}) {
|
||||||
|
t.Error("should not match unrelated host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainAttributeToMap(t *testing.T) {
|
||||||
|
attrs := []*v2geo.Domain_Attribute{
|
||||||
|
{Key: "ads"},
|
||||||
|
{Key: "news"},
|
||||||
|
}
|
||||||
|
got := domainAttributeToMap(attrs)
|
||||||
|
if len(got) != 2 || !got["ads"] || !got["news"] {
|
||||||
|
t.Errorf("domainAttributeToMap = %v, want {ads:true, news:true}", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got2 := domainAttributeToMap(nil)
|
||||||
|
if len(got2) != 0 {
|
||||||
|
t.Errorf("domainAttributeToMap(nil) = %v, want empty map", got2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGeoIPMatcher(t *testing.T) {
|
||||||
|
list := &v2geo.GeoIP{
|
||||||
|
Cidr: []*v2geo.CIDR{
|
||||||
|
{Ip: net.IPv4(10, 0, 0, 0).To4(), Prefix: 8},
|
||||||
|
{Ip: net.IPv4(192, 168, 0, 0).To4(), Prefix: 16},
|
||||||
|
},
|
||||||
|
InverseMatch: false,
|
||||||
|
}
|
||||||
|
m, err := newGeoIPMatcher(list)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||||
|
}
|
||||||
|
if len(m.N4) != 2 {
|
||||||
|
t.Errorf("expected 2 IPv4 nets, got %d", len(m.N4))
|
||||||
|
}
|
||||||
|
if m.Inverse {
|
||||||
|
t.Error("Inverse should be false")
|
||||||
|
}
|
||||||
|
// Verify sorted order: 10.0.0.0/8 < 192.168.0.0/16
|
||||||
|
if m.N4[0].IP.String() != "10.0.0.0" {
|
||||||
|
t.Errorf("N4[0] = %s, want 10.0.0.0", m.N4[0].IP)
|
||||||
|
}
|
||||||
|
if m.N4[1].IP.String() != "192.168.0.0" {
|
||||||
|
t.Errorf("N4[1] = %s, want 192.168.0.0", m.N4[1].IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGeoIPMatcher_IPv6(t *testing.T) {
|
||||||
|
list := &v2geo.GeoIP{
|
||||||
|
Cidr: []*v2geo.CIDR{
|
||||||
|
{Ip: net.ParseIP("2001:db8::"), Prefix: 32},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := newGeoIPMatcher(list)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newGeoIPMatcher error: %v", err)
|
||||||
|
}
|
||||||
|
if len(m.N6) != 1 {
|
||||||
|
t.Errorf("expected 1 IPv6 net, got %d", len(m.N6))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGeoIPMatcher_InvalidIPLength(t *testing.T) {
|
||||||
|
list := &v2geo.GeoIP{
|
||||||
|
Cidr: []*v2geo.CIDR{
|
||||||
|
{Ip: []byte{1, 2, 3}, Prefix: 24},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := newGeoIPMatcher(list)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid IP length")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGeositeMatcher(t *testing.T) {
|
||||||
|
list := &v2geo.GeoSite{
|
||||||
|
Domain: []*v2geo.Domain{
|
||||||
|
{Type: v2geo.Domain_Plain, Value: "google"},
|
||||||
|
{Type: v2geo.Domain_Full, Value: "exact.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := newGeositeMatcher(list, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||||
|
}
|
||||||
|
if len(m.Domains) != 2 {
|
||||||
|
t.Errorf("expected 2 domains, got %d", len(m.Domains))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGeositeMatcher_WithAttrs(t *testing.T) {
|
||||||
|
list := &v2geo.GeoSite{
|
||||||
|
Domain: []*v2geo.Domain{
|
||||||
|
{
|
||||||
|
Type: v2geo.Domain_RootDomain,
|
||||||
|
Value: "google.com",
|
||||||
|
Attribute: []*v2geo.Domain_Attribute{
|
||||||
|
{Key: "ads"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := newGeositeMatcher(list, []string{"ads"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newGeositeMatcher error: %v", err)
|
||||||
|
}
|
||||||
|
if !m.Match(HostInfo{Name: "www.google.com"}) {
|
||||||
|
t.Error("should match with root domain and attr")
|
||||||
|
}
|
||||||
|
}
|
||||||
115
ruleset/expr.go
115
ruleset/expr.go
@@ -31,6 +31,9 @@ type ExprRule struct {
|
|||||||
Log bool `yaml:"log"`
|
Log bool `yaml:"log"`
|
||||||
Modifier ModifierEntry `yaml:"modifier"`
|
Modifier ModifierEntry `yaml:"modifier"`
|
||||||
Expr string `yaml:"expr"`
|
Expr string `yaml:"expr"`
|
||||||
|
StartTime string `yaml:"start_time"`
|
||||||
|
StopTime string `yaml:"stop_time"`
|
||||||
|
Weekdays []string `yaml:"weekdays"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModifierEntry struct {
|
type ModifierEntry struct {
|
||||||
@@ -56,6 +59,10 @@ type compiledExprRule struct {
|
|||||||
ModInstance modifier.Instance
|
ModInstance modifier.Instance
|
||||||
Program *vm.Program
|
Program *vm.Program
|
||||||
GeoSiteConditions []string
|
GeoSiteConditions []string
|
||||||
|
StartTimeSecs int // seconds since midnight, -1 if unset
|
||||||
|
StopTimeSecs int // seconds since midnight, -1 if unset
|
||||||
|
Weekdays []time.Weekday
|
||||||
|
WeekdaysNegated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Ruleset = (*exprRuleset)(nil)
|
var _ Ruleset = (*exprRuleset)(nil)
|
||||||
@@ -73,10 +80,13 @@ func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
|||||||
|
|
||||||
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
|
||||||
env := streamInfoToExprEnv(info)
|
env := streamInfoToExprEnv(info)
|
||||||
|
now := time.Now()
|
||||||
for _, rule := range r.Rules {
|
for _, rule := range r.Rules {
|
||||||
|
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
v, err := vm.Run(rule.Program, env)
|
v, err := vm.Run(rule.Program, env)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log the error and continue to the next rule.
|
|
||||||
r.Logger.MatchError(info, rule.Name, err)
|
r.Logger.MatchError(info, rule.Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -163,12 +173,34 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
depAnMap[name] = a
|
depAnMap[name] = a
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
startSecs := -1
|
||||||
|
if rule.StartTime != "" {
|
||||||
|
startSecs, err = parseTimeOfDay(rule.StartTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rule %q has invalid start_time: %w", rule.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stopSecs := -1
|
||||||
|
if rule.StopTime != "" {
|
||||||
|
stopSecs, err = parseTimeOfDay(rule.StopTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rule %q has invalid stop_time: %w", rule.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
weekdays, weekdaysNegated, err := parseWeekdays(rule.Weekdays)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
|
||||||
|
}
|
||||||
cr := compiledExprRule{
|
cr := compiledExprRule{
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Action: action,
|
Action: action,
|
||||||
Log: rule.Log,
|
Log: rule.Log,
|
||||||
Program: program,
|
Program: program,
|
||||||
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
|
||||||
|
StartTimeSecs: startSecs,
|
||||||
|
StopTimeSecs: stopSecs,
|
||||||
|
Weekdays: weekdays,
|
||||||
|
WeekdaysNegated: weekdaysNegated,
|
||||||
}
|
}
|
||||||
if action != nil && *action == ActionModify {
|
if action != nil && *action == ActionModify {
|
||||||
mod, ok := fullModMap[rule.Modifier.Name]
|
mod, ok := fullModMap[rule.Modifier.Name]
|
||||||
@@ -391,6 +423,87 @@ func buildFunctionMap(config *BuiltinConfig) (map[string]*Function, *geo.GeoMatc
|
|||||||
}, geoMatcher
|
}, geoMatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func matchTime(now time.Time, startSecs, stopSecs int, weekdays []time.Weekday, negated bool) bool {
|
||||||
|
if startSecs >= 0 || stopSecs >= 0 {
|
||||||
|
currentSecs := now.Hour()*3600 + now.Minute()*60 + now.Second()
|
||||||
|
if startSecs >= 0 && stopSecs >= 0 {
|
||||||
|
if startSecs <= stopSecs {
|
||||||
|
if currentSecs < startSecs || currentSecs > stopSecs {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if currentSecs < startSecs && currentSecs > stopSecs {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if startSecs >= 0 {
|
||||||
|
if currentSecs < startSecs {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else if currentSecs > stopSecs {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(weekdays) > 0 {
|
||||||
|
current := now.Weekday()
|
||||||
|
found := false
|
||||||
|
for _, d := range weekdays {
|
||||||
|
if current == d {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if negated == found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTimeOfDay(s string) (int, error) {
|
||||||
|
t, err := time.Parse("15:04:05", s)
|
||||||
|
if err != nil {
|
||||||
|
return -1, fmt.Errorf("invalid time %q (expected hh:mm:ss)", s)
|
||||||
|
}
|
||||||
|
return t.Hour()*3600 + t.Minute()*60 + t.Second(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWeekdays(days []string) ([]time.Weekday, bool, error) {
|
||||||
|
if len(days) == 0 {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
negated := false
|
||||||
|
parsed := make([]time.Weekday, 0, len(days))
|
||||||
|
for i, d := range days {
|
||||||
|
d = strings.TrimSpace(d)
|
||||||
|
if i == 0 && strings.HasPrefix(d, "!") {
|
||||||
|
negated = true
|
||||||
|
d = strings.TrimSpace(strings.TrimPrefix(d, "!"))
|
||||||
|
}
|
||||||
|
var wd time.Weekday
|
||||||
|
switch strings.ToLower(d) {
|
||||||
|
case "sun", "sunday":
|
||||||
|
wd = time.Sunday
|
||||||
|
case "mon", "monday":
|
||||||
|
wd = time.Monday
|
||||||
|
case "tue", "tues", "tuesday":
|
||||||
|
wd = time.Tuesday
|
||||||
|
case "wed", "wednesday":
|
||||||
|
wd = time.Wednesday
|
||||||
|
case "thu", "thur", "thurs", "thursday":
|
||||||
|
wd = time.Thursday
|
||||||
|
case "fri", "friday":
|
||||||
|
wd = time.Friday
|
||||||
|
case "sat", "saturday":
|
||||||
|
wd = time.Saturday
|
||||||
|
default:
|
||||||
|
return nil, false, fmt.Errorf("invalid weekday %q", d)
|
||||||
|
}
|
||||||
|
parsed = append(parsed, wd)
|
||||||
|
}
|
||||||
|
return parsed, negated, nil
|
||||||
|
}
|
||||||
|
|
||||||
const rulesetLogMetaKey = "_ruleset"
|
const rulesetLogMetaKey = "_ruleset"
|
||||||
|
|
||||||
func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo {
|
func addGeoSiteLogMetadata(info StreamInfo, gm *geo.GeoMatcher, conditions []string) StreamInfo {
|
||||||
|
|||||||
121
ruleset/interface_test.go
Normal file
121
ruleset/interface_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package ruleset
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAction_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
action Action
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{ActionMaybe, "maybe"},
|
||||||
|
{ActionAllow, "allow"},
|
||||||
|
{ActionBlock, "block"},
|
||||||
|
{ActionDrop, "drop"},
|
||||||
|
{ActionModify, "modify"},
|
||||||
|
{Action(99), "unknown"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.want, func(t *testing.T) {
|
||||||
|
if got := tt.action.String(); got != tt.want {
|
||||||
|
t.Errorf("Action.String() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocol_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
protocol Protocol
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{ProtocolTCP, "tcp"},
|
||||||
|
{ProtocolUDP, "udp"},
|
||||||
|
{Protocol(99), "unknown"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.want, func(t *testing.T) {
|
||||||
|
if got := tt.protocol.String(); got != tt.want {
|
||||||
|
t.Errorf("Protocol.String() = %q, want %q", got, tt.protocol)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocol_Constants(t *testing.T) {
|
||||||
|
if ProtocolTCP != 0 {
|
||||||
|
t.Errorf("ProtocolTCP = %d, want 0", ProtocolTCP)
|
||||||
|
}
|
||||||
|
if ProtocolUDP != 1 {
|
||||||
|
t.Errorf("ProtocolUDP = %d, want 1", ProtocolUDP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAction_Constants(t *testing.T) {
|
||||||
|
if ActionMaybe != 0 {
|
||||||
|
t.Errorf("ActionMaybe = %d, want 0", ActionMaybe)
|
||||||
|
}
|
||||||
|
if ActionAllow != 1 {
|
||||||
|
t.Errorf("ActionAllow = %d, want 1", ActionAllow)
|
||||||
|
}
|
||||||
|
if ActionBlock != 2 {
|
||||||
|
t.Errorf("ActionBlock = %d, want 2", ActionBlock)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamInfo_SrcString(t *testing.T) {
|
||||||
|
info := StreamInfo{
|
||||||
|
SrcIP: net.ParseIP("192.168.1.1"),
|
||||||
|
SrcPort: 8080,
|
||||||
|
}
|
||||||
|
want := "192.168.1.1:8080"
|
||||||
|
if got := info.SrcString(); got != want {
|
||||||
|
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamInfo_DstString(t *testing.T) {
|
||||||
|
info := StreamInfo{
|
||||||
|
DstIP: net.ParseIP("10.0.0.1"),
|
||||||
|
DstPort: 443,
|
||||||
|
}
|
||||||
|
want := "10.0.0.1:443"
|
||||||
|
if got := info.DstString(); got != want {
|
||||||
|
t.Errorf("DstString() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamInfo_SrcString_IPv6(t *testing.T) {
|
||||||
|
info := StreamInfo{
|
||||||
|
SrcIP: net.ParseIP("::1"),
|
||||||
|
SrcPort: 53,
|
||||||
|
}
|
||||||
|
want := "[::1]:53"
|
||||||
|
if got := info.SrcString(); got != want {
|
||||||
|
t.Errorf("SrcString() = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchResult_ZeroValue(t *testing.T) {
|
||||||
|
var mr MatchResult
|
||||||
|
if mr.Action != ActionMaybe {
|
||||||
|
t.Errorf("zero MatchResult.Action = %v, want ActionMaybe (0)", mr.Action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamInfo_PropsInitialization(t *testing.T) {
|
||||||
|
info := StreamInfo{
|
||||||
|
Props: analyzer.CombinedPropMap{
|
||||||
|
"tls": analyzer.PropMap{"sni": "example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if info.Props.Get("tls", "sni") != "example.com" {
|
||||||
|
t.Error("StreamInfo.Props not properly initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user