test: improve coverage across package

This commit is contained in:
2026-05-01 14:09:10 +05:30
parent e1c68ec7d0
commit e3f1f5046a
18 changed files with 2652 additions and 6 deletions

124
analyzer/interface_test.go Normal file
View File

@@ -0,0 +1,124 @@
package analyzer
import (
"reflect"
"testing"
)
func TestPropMap_Get(t *testing.T) {
m := PropMap{
"a": "value-a",
"nested": PropMap{
"b": "value-b",
"deep": PropMap{
"c": "value-c",
},
},
}
tests := []struct {
name string
key string
want interface{}
}{
{"simple key", "a", "value-a"},
{"nested key", "nested.b", "value-b"},
{"deeply nested", "nested.deep.c", "value-c"},
{"non-existent top", "x", nil},
{"non-existent nested", "nested.x", nil},
{"empty key", "", nil},
{"partial path", "nested", PropMap{"b": "value-b", "deep": PropMap{"c": "value-c"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := m.Get(tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("PropMap.Get(%q) = %v, want %v", tt.key, got, tt.want)
}
})
}
}
func TestPropMap_Get_NilMap(t *testing.T) {
var m PropMap
if got := m.Get("any"); got != nil {
t.Errorf("nil PropMap.Get() = %v, want nil", got)
}
}
func TestPropMap_Get_EmptyMap(t *testing.T) {
m := PropMap{}
if got := m.Get("any"); got != nil {
t.Errorf("empty PropMap.Get() = %v, want nil", got)
}
}
func TestPropMap_Get_IntermediateNonMap(t *testing.T) {
m := PropMap{
"a": "hello",
}
if got := m.Get("a.b.c"); got != nil {
t.Errorf("PropMap.Get() through non-map = %v, want nil", got)
}
}
func TestCombinedPropMap_Get(t *testing.T) {
cm := CombinedPropMap{
"tls": PropMap{
"req": PropMap{
"sni": "example.com",
},
},
"http": PropMap{
"resp": PropMap{
"status": 200,
},
},
}
tests := []struct {
name string
analyzer string
key string
want interface{}
}{
{"tls sni", "tls", "req.sni", "example.com"},
{"http status", "http", "resp.status", 200},
{"unknown analyzer", "dns", "any", nil},
{"valid analyzer bad key", "tls", "req.nonexistent", nil},
{"empty analyzer", "", "key", nil},
{"empty key", "tls", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cm.Get(tt.analyzer, tt.key)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CombinedPropMap.Get(%q, %q) = %v, want %v", tt.analyzer, tt.key, got, tt.want)
}
})
}
}
func TestCombinedPropMap_Get_Nil(t *testing.T) {
var cm CombinedPropMap
if got := cm.Get("any", "key"); got != nil {
t.Errorf("nil CombinedPropMap.Get() = %v, want nil", got)
}
}
func TestPropUpdateType_Values(t *testing.T) {
if PropUpdateNone != 0 {
t.Errorf("PropUpdateNone = %d, want 0", PropUpdateNone)
}
if PropUpdateMerge != 1 {
t.Errorf("PropUpdateMerge = %d, want 1", PropUpdateMerge)
}
if PropUpdateReplace != 2 {
t.Errorf("PropUpdateReplace = %d, want 2", PropUpdateReplace)
}
if PropUpdateDelete != 3 {
t.Errorf("PropUpdateDelete = %d, want 3", PropUpdateDelete)
}
}

View File

@@ -0,0 +1,319 @@
package internal
import (
"reflect"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
func buildClientHelloMsg(t *testing.T) *utils.ByteBuffer {
t.Helper()
// Bytes taken from the standard TLS test vector, starting after the record
// header (5 bytes) and handshake header (4 bytes).
body := []byte{
0x03, 0x03, // version TLS 1.2
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length = 0
0x00, 0x20, // cipher suites length = 32 bytes = 16 suites
0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30,
0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09,
0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d,
0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, // ciphers
0x01, // compression methods length = 1
0x00, // compression method = null
0x00, 0x58, // extensions length = 88
// extension: server_name
0x00, 0x00, // type = server_name
0x00, 0x18, // length = 24
0x00, 0x16, // server name list length = 22
0x00, // name type = hostname
0x00, 0x13, // name length = 19
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'u', 'l', 'f', 'h', 'e', 'i', 'm', '.', 'n', 'e', 't',
// extension: status_request
0x00, 0x05, // type = status_request
0x00, 0x05, // length = 5
0x01, 0x00, 0x00, 0x00, 0x00,
// extension: supported_groups
0x00, 0x0a, // type = supported_groups
0x00, 0x0a, // length = 10
0x00, 0x08, // list length = 8
0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19,
// extension: ec_point_formats
0x00, 0x0b, // type = ec_point_formats
0x00, 0x02, // length = 2
0x01, 0x00,
// extension: signature_algorithms
0x00, 0x0d, // type = signature_algorithms
0x00, 0x12, // length = 18
0x00, 0x10, // list length = 16
0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03,
0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03,
// extension: renegotiation_info
0xff, 0x01, // type = renegotiation_info
0x00, 0x01, // length = 1
0x00,
// extension: extended_master_secret (empty)
0x00, 0x17, // type = extended_master_secret
0x00, 0x00, // length = 0
// extension: session_ticket (empty)
0x00, 0x23, // type = session_ticket
0x00, 0x00, // length = 0
}
return &utils.ByteBuffer{Buf: body}
}
func TestParseTLSClientHelloMsgData(t *testing.T) {
chBuf := buildClientHelloMsg(t)
m := ParseTLSClientHelloMsgData(chBuf)
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
wantVersion := uint16(0x0303)
if v, ok := m["version"].(uint16); !ok || v != wantVersion {
t.Errorf("version = %v, want %v", m["version"], wantVersion)
}
wantCiphers := []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10}
if c, ok := m["ciphers"].([]uint16); ok {
if !reflect.DeepEqual(c, wantCiphers) {
t.Errorf("ciphers = %v, want %v", c, wantCiphers)
}
} else {
t.Errorf("ciphers missing or wrong type: %T", m["ciphers"])
}
if sni, ok := m["sni"].(string); !ok || sni != "example.ulfheim.net" {
t.Errorf("sni = %q, want %q", m["sni"], "example.ulfheim.net")
}
if _, ok := m["compression"]; !ok {
t.Error("compression key missing")
}
if _, ok := m["session"]; !ok {
t.Error("session key missing")
}
if _, ok := m["random"]; !ok {
t.Error("random key missing")
}
}
func TestParseTLSServerHelloMsgData(t *testing.T) {
// ServerHello message body (after record header + handshake header)
body := []byte{
0x03, 0x03, // version TLS 1.2
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, // random
0x00, // session ID length = 0
0xc0, 0x13, // cipher suite
0x00, // compression method
0x00, 0x05, // extensions length = 5
// extension: renegotiation_info
0xff, 0x01, // type
0x00, 0x01, // length = 1
0x00,
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil")
}
wantCipher := uint16(0xc013)
if c, ok := m["cipher"].(uint16); !ok || c != wantCipher {
t.Errorf("cipher = %v, want %v", m["cipher"], wantCipher)
}
wantCompression := uint8(0)
if c, ok := m["compression"].(uint8); !ok || c != wantCompression {
t.Errorf("compression = %v, want %v", m["compression"], wantCompression)
}
}
func TestParseTLSServerHelloMsgData_NoExtensions(t *testing.T) {
// ServerHello message without extensions
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0xff, // cipher suite
0x00, // compression method
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil for no-extensions case")
}
if _, ok := m["cipher"]; !ok {
t.Error("cipher key missing")
}
}
func TestParseTLSClientHelloMsgData_Truncated(t *testing.T) {
tests := []struct {
name string
buf []byte
}{
{"too short for session id", []byte{0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}},
{"odd cipher suites length", []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x03, // odd cipher suites length
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: tt.buf})
if m != nil {
t.Error("expected nil for truncated input")
}
})
}
}
func TestParseTLSClientHelloMsgData_ECH(t *testing.T) {
// ClientHello with ECH extension
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // TLS_AES_128_GCM_SHA256
0x01, // compression methods length = 1
0x00, // null compression
0x00, 0x07, // extensions length = 7
// ECH extension
0xfe, 0x0d, // type = encrypted_client_hello
0x00, 0x03, // length = 3
0x01, 0x02, 0x03, // some ECH data
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
ech, ok := m["ech"].(bool)
if !ok || !ech {
t.Errorf("ech = %v, want true", m["ech"])
}
}
func TestParseTLSClientHelloMsgData_ALPN(t *testing.T) {
// ClientHello with ALPN extension
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // cipher suite
0x01, // compression methods length = 1
0x00, // compression method
0x00, 0x12, // extensions length = 18
// ALPN extension
0x00, 0x10, // type = ALPN
0x00, 0x0e, // length = 14
0x00, 0x0c, // list length = 12
0x08, 'h', 't', 't', 'p', '/', '1', '.', '1',
0x02, 'h', '2',
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
alpn, ok := m["alpn"].([]string)
if !ok {
t.Fatalf("alpn missing or wrong type: %T", m["alpn"])
}
if len(alpn) != 2 || alpn[0] != "http/1.1" || alpn[1] != "h2" {
t.Errorf("alpn = %v, want [http/1.1 h2]", alpn)
}
}
func TestParseTLSClientHelloMsgData_SupportedVersionsClient(t *testing.T) {
// ClientHello with supported_versions extension (client format - list)
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x00, 0x02, // cipher suites length = 2
0x13, 0x01, // cipher suite
0x01, // compression methods length = 1
0x00, // compression method
0x00, 0x0b, // extensions length = 11
// supported_versions (client list format)
0x00, 0x2b, // type = supported_versions
0x00, 0x07, // length = 7
0x06, // list length = 6
0x03, 0x04, // TLS 1.3
0x03, 0x03, // TLS 1.2
0x03, 0x01, // TLS 1.0
}
m := ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSClientHelloMsgData returned nil")
}
versions, ok := m["supported_versions"].([]uint16)
if !ok {
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
}
want := []uint16{0x0304, 0x0303, 0x0301}
if !reflect.DeepEqual(versions, want) {
t.Errorf("supported_versions = %v, want %v", versions, want)
}
}
func TestParseTLSServerHelloMsgData_SupportedVersionsServer(t *testing.T) {
// ServerHello with supported_versions extension (server format - single value)
body := []byte{
0x03, 0x03, // version
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // random
0x00, // session ID length
0x13, 0x01, // cipher suite
0x00, // compression method
0x00, 0x06, // extensions length = 6
// supported_versions (server format - single value)
0x00, 0x2b, // type
0x00, 0x02, // length = 2
0x03, 0x04, // TLS 1.3
}
m := ParseTLSServerHelloMsgData(&utils.ByteBuffer{Buf: body})
if m == nil {
t.Fatal("ParseTLSServerHelloMsgData returned nil")
}
v, ok := m["supported_versions"].(uint16)
if !ok {
t.Fatalf("supported_versions missing or wrong type: %T", m["supported_versions"])
}
if v != 0x0304 {
t.Errorf("supported_versions = 0x%04x, want 0x0304", v)
}
}

256
analyzer/tcp/fet_test.go Normal file
View File

@@ -0,0 +1,256 @@
package tcp
import (
"testing"
)
func TestPopCount(t *testing.T) {
tests := []struct {
input byte
want int
}{
{0x00, 0},
{0x01, 1},
{0x02, 1},
{0x03, 2},
{0x07, 3},
{0x0f, 4},
{0xff, 8},
{0x55, 4}, // 01010101
{0xaa, 4}, // 10101010
}
for _, tt := range tests {
if got := popCount(tt.input); got != tt.want {
t.Errorf("popCount(0x%02x) = %d, want %d", tt.input, got, tt.want)
}
}
}
func TestAveragePopCount(t *testing.T) {
tests := []struct {
name string
input []byte
want float32
}{
{"empty", []byte{}, 0},
{"all zeros", []byte{0x00, 0x00, 0x00}, 0},
{"all ones", []byte{0xff, 0xff}, 8.0},
{"mixed", []byte{0x00, 0xff}, 4.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := averagePopCount(tt.input)
if got != tt.want {
t.Errorf("averagePopCount() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsFirstSixPrintable(t *testing.T) {
tests := []struct {
name string
input []byte
want bool
}{
{"too short", []byte("abc"), false},
{"all printable", []byte("abcdef"), true},
{"non-printable at pos 0", []byte{0x00, 'a', 'b', 'c', 'd', 'e'}, false},
{"non-printable at pos 5", []byte{'a', 'b', 'c', 'd', 'e', 0x1f}, false},
{"exactly 6 printable", []byte("123456"), true},
{"spaces", []byte(" "), true},
{"non-printable middle", []byte{'a', 'b', 0x01, 'd', 'e', 'f'}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isFirstSixPrintable(tt.input)
if got != tt.want {
t.Errorf("isFirstSixPrintable(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestPrintablePercentage(t *testing.T) {
tests := []struct {
name string
input []byte
want float32
}{
{"empty", []byte{}, 0},
{"all printable", []byte("hello"), 1.0},
{"none printable", []byte{0x00, 0x01, 0x02, 0x03}, 0},
{"half printable", []byte{'a', 0x00, 'b', 0x00}, 0.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := printablePercentage(tt.input)
if got != tt.want {
t.Errorf("printablePercentage() = %v, want %v", got, tt.want)
}
})
}
}
func TestContiguousPrintable(t *testing.T) {
tests := []struct {
name string
input []byte
want int
}{
{"empty", []byte{}, 0},
{"all printable", []byte("hello world"), 11},
{"none printable", []byte{0x00, 0x01, 0x02}, 0},
{"start printable", []byte{'a', 'b', 'c', 0x00, 'd', 'e', 'f'}, 3},
{"end printable", []byte{0x00, 'a', 'b', 'c', 'd', 'e', 'f'}, 6},
{"middle printable", []byte{0x00, 'a', 'b', 'c', 0x00}, 3},
{"two segments", []byte{'a', 'b', 0x00, 'c', 'd', 'e'}, 3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := contiguousPrintable(tt.input)
if got != tt.want {
t.Errorf("contiguousPrintable() = %d, want %d", got, tt.want)
}
})
}
}
func TestIsTLSorHTTP(t *testing.T) {
tests := []struct {
name string
input []byte
want bool
}{
{"too short", []byte("AB"), false},
// TLS ClientHello: 0x16 0x03 0x01
{"tls 1.0", []byte{0x16, 0x03, 0x01}, true},
// TLS 0x17 is application data record
{"tls app data", []byte{0x17, 0x03, 0x03}, true},
{"tls max content type", []byte{0x17, 0x03, 0x09}, true},
{"bad tls content type", []byte{0x15, 0x03, 0x01}, false},
{"bad tls version", []byte{0x16, 0x04, 0x01}, false},
{"bad tls length", []byte{0x16, 0x03, 0x0a}, false},
// HTTP methods
{"GET", []byte("GET / HTTP/1.1..."), true},
{"HEAD", []byte("HEAD /index.html..."), true},
{"POST", []byte("POST /api..."), true},
{"PUT", []byte("PUT /data..."), true},
{"DELETE", []byte("DELETE /..."), true},
{"CONNECT", []byte("CONNECT proxy..."), true},
{"OPTIONS", []byte("OPTIONS *..."), true},
{"TRACE", []byte("TRACE /..."), true},
{"PATCH", []byte("PATCH /..."), true},
{"random data", []byte{0x00, 0x01, 0x02, 0x03}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isTLSorHTTP(tt.input)
if got != tt.want {
t.Errorf("isTLSorHTTP(%v) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestIsPrintable(t *testing.T) {
if !isPrintable('a') {
t.Error("'a' should be printable")
}
if isPrintable(0x00) {
t.Error("0x00 should not be printable")
}
if !isPrintable(0x20) {
t.Error("0x20 (space) should be printable")
}
if !isPrintable(0x7e) {
t.Error("0x7e (~) should be printable")
}
if isPrintable(0x7f) {
t.Error("0x7f (DEL) should not be printable")
}
}
func TestFETStream_Feed(t *testing.T) {
s := newFETStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("FET should be done after first packet")
}
m := u.M
if _, ok := m["ex1"]; !ok {
t.Error("ex1 missing")
}
if _, ok := m["ex2"]; !ok {
t.Error("ex2 missing")
}
if _, ok := m["ex3"]; !ok {
t.Error("ex3 missing")
}
if _, ok := m["ex4"]; !ok {
t.Error("ex4 missing")
}
if _, ok := m["ex5"]; !ok {
t.Error("ex5 missing")
}
if yes, ok := m["yes"].(bool); ok && yes {
t.Error("HTTP should be exempt (yes=false)")
}
if u.Type != 2 {
t.Errorf("prop update type = %d, want PropUpdateReplace", u.Type)
}
}
func TestFETStream_Feed_EncryptedLike(t *testing.T) {
s := newFETStream(nil)
data := make([]byte, 100)
for i := range data {
data[i] = byte(i % 256)
}
u, done := s.Feed(false, false, false, 0, data)
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("should be done")
}
}
func TestFETStream_Feed_Skip(t *testing.T) {
s := newFETStream(nil)
_, done := s.Feed(false, false, false, 5, []byte("data"))
if !done {
t.Error("skip != 0 should return done=true")
}
}
func TestFETStream_Feed_Empty(t *testing.T) {
s := newFETStream(nil)
u, done := s.Feed(false, false, false, 0, []byte{})
if u != nil || done {
t.Error("empty data should return nil, false")
}
}
func TestFETAnalyzer_Name(t *testing.T) {
a := &FETAnalyzer{}
if a.Name() != "fet" {
t.Errorf("Name() = %q, want fet", a.Name())
}
}
func TestFETStream_Close(t *testing.T) {
s := newFETStream(nil)
if u := s.Close(false); u != nil {
t.Error("Close should return nil")
}
}

159
analyzer/tcp/ssh_test.go Normal file
View File

@@ -0,0 +1,159 @@
package tcp
import (
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
func TestSSHAnalyzer_Name(t *testing.T) {
a := &SSHAnalyzer{}
if a.Name() != "ssh" {
t.Errorf("Name() = %q, want ssh", a.Name())
}
}
func TestSSHStream_Feed_Client(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9p1 Ubuntu-3\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done (server not yet received)")
}
client, ok := u.M["client"].(analyzer.PropMap)
if !ok {
t.Fatal("client prop missing")
}
if client["protocol"] != "2.0" {
t.Errorf("protocol = %v, want 2.0", client["protocol"])
}
if client["software"] != "OpenSSH_8.9p1" {
t.Errorf("software = %v, want OpenSSH_8.9p1", client["software"])
}
if comments, ok := client["comments"]; ok {
if comments != "Ubuntu-3" {
t.Errorf("comments = %v, want Ubuntu-3", comments)
}
}
}
func TestSSHStream_Feed_ClientWithComments(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_7.4 Ubuntu-3\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done")
}
client := u.M["client"].(analyzer.PropMap)
if client["comments"] != "Ubuntu-3" {
t.Errorf("comments = %v, want Ubuntu-3", client["comments"])
}
}
func TestSSHStream_Feed_Both(t *testing.T) {
s := newSSHStream(nil)
s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH_8.9\r\n"))
u, done := s.Feed(true, false, false, 0, []byte("SSH-2.0-dropbear_2022.83\r\n"))
if u == nil {
t.Fatal("Feed returned nil update")
}
if !done {
t.Error("should be done after both sides")
}
server, ok := u.M["server"].(analyzer.PropMap)
if !ok {
t.Fatal("server prop missing")
}
if server["software"] != "dropbear_2022.83" {
t.Errorf("server software = %v", server["software"])
}
}
func TestSSHStream_Feed_NotSSH(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("HTTP/1.1 200 OK\r\n"))
if u != nil {
t.Error("should return nil for non-SSH")
}
if !done {
t.Error("should be cancelled (done) for non-SSH")
}
}
func TestSSHStream_Feed_InvalidLine(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-foo bar baz\r\n"))
if u != nil {
t.Error("should return nil for invalid line (>2 fields)")
}
if !done {
t.Error("should be cancelled for invalid SSH line")
}
}
func TestSSHStream_Feed_NoLineEnd(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte("SSH-2.0-OpenSSH"))
if u != nil {
t.Error("should return nil when no EOL found yet")
}
if done {
t.Error("should not be done, waiting for more data")
}
}
func TestSSHStream_Feed_IncompleteThenComplete(t *testing.T) {
s := newSSHStream(nil)
u, _ := s.Feed(false, false, false, 0, []byte("SSH-2.0-"))
if u != nil {
t.Error("first partial feed should return nil")
}
u, done := s.Feed(false, false, false, 0, []byte("Dropbear\r\n"))
if u == nil {
t.Fatal("second feed should return update")
}
if done {
t.Error("should not be done (only client)")
}
client := u.M["client"].(analyzer.PropMap)
if client["software"] != "Dropbear" {
t.Errorf("software = %v", client["software"])
}
}
func TestSSHStream_Feed_Skip(t *testing.T) {
s := newSSHStream(nil)
_, done := s.Feed(false, false, false, 5, []byte("data"))
if !done {
t.Error("skip != 0 should return done=true")
}
}
func TestSSHStream_Feed_Empty(t *testing.T) {
s := newSSHStream(nil)
u, done := s.Feed(false, false, false, 0, []byte{})
if u != nil || done {
t.Error("empty data should return nil, false")
}
}
func TestSSHStream_Close(t *testing.T) {
s := newSSHStream(nil)
s.clientBuf.Append([]byte("data"))
s.serverBuf.Append([]byte("data"))
s.clientMap = analyzer.PropMap{"key": "val"}
u := s.Close(false)
if u != nil {
t.Error("Close should return nil")
}
if s.clientBuf.Len() != 0 || s.serverBuf.Len() != 0 {
t.Error("Close should reset buffers")
}
if s.clientMap != nil || s.serverMap != nil {
t.Error("Close should nil maps")
}
}

View File

@@ -0,0 +1,187 @@
package udp
import (
"encoding/binary"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
func makeWireGuardPacket(msgType byte, body []byte) []byte {
buf := make([]byte, 1+3+len(body))
buf[0] = msgType
copy(buf[4:], body)
return buf
}
func TestWireGuardUDPStream_Feed_HandshakeInitiation(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeInitiation-4)
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
u, done := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
if done {
t.Error("should not be done")
}
if u.Type != analyzer.PropUpdateReplace {
t.Errorf("Type = %d, want PropUpdateReplace", u.Type)
}
initMap, ok := u.M["handshake_initiation"].(analyzer.PropMap)
if !ok {
t.Fatal("handshake_initiation missing")
}
if idx := initMap["sender_index"].(uint32); idx != 0x01020304 {
t.Errorf("sender_index = %d, want %d", idx, 0x01020304)
}
}
func TestWireGuardUDPStream_Feed_HandshakeResponse(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeResponse-4)
binary.LittleEndian.PutUint32(body[0:4], 0x01020304)
binary.LittleEndian.PutUint32(body[4:8], 0x05060708)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeResponse, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
respMap, ok := u.M["handshake_response"].(analyzer.PropMap)
if !ok {
t.Fatal("handshake_response missing")
}
if idx := respMap["sender_index"].(uint32); idx != 0x01020304 {
t.Errorf("sender_index = %d", idx)
}
if idx := respMap["receiver_index"].(uint32); idx != 0x05060708 {
t.Errorf("receiver_index = %d", idx)
}
}
func TestWireGuardUDPStream_Feed_PacketData(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardMinSizePacketData-4)
binary.LittleEndian.PutUint32(body[0:4], 0x0a0b0c0d)
binary.LittleEndian.PutUint64(body[4:12], 42)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeData, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
dataMap, ok := u.M["packet_data"].(analyzer.PropMap)
if !ok {
t.Fatal("packet_data missing")
}
if idx := dataMap["receiver_index"].(uint32); idx != 0x0a0b0c0d {
t.Errorf("receiver_index = %d", idx)
}
if ctr := dataMap["counter"].(uint64); ctr != 42 {
t.Errorf("counter = %d", ctr)
}
}
func TestWireGuardUDPStream_Feed_CookieReply(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizePacketCookieReply-4)
binary.LittleEndian.PutUint32(body[0:4], 0x11111111)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeCookieReply, body))
if u == nil {
t.Fatal("Feed returned nil update")
}
crMap, ok := u.M["packet_cookie_reply"].(analyzer.PropMap)
if !ok {
t.Fatal("packet_cookie_reply missing")
}
if idx := crMap["receiver_index"].(uint32); idx != 0x11111111 {
t.Errorf("receiver_index = %d", idx)
}
}
func TestWireGuardUDPStream_Feed_InvalidLength(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, done := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
if u == nil || !done {
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
}
if u == nil || !done {
u, _ = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
}
u, done = s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 0, 0x01})
if u != nil || !done {
t.Error("should return done after 4 invalid packets")
}
}
func TestWireGuardUDPStream_Feed_TooShort(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, _ := s.Feed(false, []byte{1, 0, 0})
if u != nil {
t.Error("should return nil for too-short packet")
}
}
func TestWireGuardUDPStream_Feed_NonZeroReserved(t *testing.T) {
s := newWireGuardUDPStream(nil)
u, _ := s.Feed(false, []byte{wireguardTypeHandshakeInitiation, 0, 0, 1, 0, 0, 0, 0})
if u != nil {
t.Error("should return nil when reserved bytes are non-zero")
}
}
func TestWireGuardUDPStream_Feed_HandshakeInitiation_WrongSize(t *testing.T) {
s := newWireGuardUDPStream(nil)
body := make([]byte, wireguardSizeHandshakeInitiation-4+1)
binary.LittleEndian.PutUint32(body[0:4], 1)
u, _ := s.Feed(false, makeWireGuardPacket(wireguardTypeHandshakeInitiation, body))
if u != nil {
t.Error("should return nil for wrong init size")
}
}
func TestWireGuardUDPStream_Close(t *testing.T) {
s := newWireGuardUDPStream(nil)
u := s.Close(false)
if u != nil {
t.Error("Close() should return nil")
}
}
func TestWireGuardUDPStream_PutSenderIndex_MatchReceiverIndex(t *testing.T) {
s := newWireGuardUDPStream(nil)
s.putSenderIndex(false, 0xdeadbeef)
if !s.matchReceiverIndex(true, 0xdeadbeef) {
t.Error("should match reverse direction")
}
if s.matchReceiverIndex(false, 0xdeadbeef) {
t.Error("should not match same direction")
}
if s.matchReceiverIndex(true, 0x12345678) {
t.Error("should not match wrong index")
}
}
func TestWireGuardUDPStream_RememberedIndexRing(t *testing.T) {
s := newWireGuardUDPStream(nil)
for i := uint32(0); i < wireguardRememberedIndexCount+2; i++ {
s.putSenderIndex(false, i)
}
if s.matchReceiverIndex(true, 0) {
t.Error("index 0 should have been evicted from ring")
}
if !s.matchReceiverIndex(true, uint32(wireguardRememberedIndexCount+1)) {
t.Error("latest index should still be in ring")
}
}
func TestWireGuardAnalyzer_Name(t *testing.T) {
a := &WireGuardAnalyzer{}
if a.Name() != "wireguard" {
t.Errorf("Name() = %q, want wireguard", a.Name())
}
}

View File

@@ -0,0 +1,321 @@
package utils
import (
"bytes"
"reflect"
"testing"
)
func TestByteBuffer_Append(t *testing.T) {
b := &ByteBuffer{}
b.Append([]byte("hello"))
b.Append([]byte(" "))
b.Append([]byte("world"))
if string(b.Buf) != "hello world" {
t.Errorf("Append() result = %q, want %q", string(b.Buf), "hello world")
}
}
func TestByteBuffer_Len(t *testing.T) {
b := &ByteBuffer{}
if b.Len() != 0 {
t.Errorf("Len() = %d, want 0", b.Len())
}
b.Append([]byte("abc"))
if b.Len() != 3 {
t.Errorf("Len() = %d, want 3", b.Len())
}
}
func TestByteBuffer_Index(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
if i := b.Index([]byte("world")); i != 6 {
t.Errorf("Index('world') = %d, want 6", i)
}
if i := b.Index([]byte("xyz")); i != -1 {
t.Errorf("Index('xyz') = %d, want -1", i)
}
if i := b.Index([]byte{}); i != 0 {
t.Errorf("Index('') = %d, want 0", i)
}
}
func TestByteBuffer_Get(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abcdef")}
data, ok := b.Get(3, true)
if !ok {
t.Fatal("Get(3, true) returned false")
}
if !bytes.Equal(data, []byte("abc")) {
t.Errorf("Get(3, true) data = %q, want %q", data, "abc")
}
if !bytes.Equal(b.Buf, []byte("def")) {
t.Errorf("after consume, Buf = %q, want %q", b.Buf, "def")
}
data, ok = b.Get(4, false)
if ok {
t.Fatal("Get(4, false) should return false (only 3 bytes left)")
}
data, ok = b.Get(2, false)
if !ok {
t.Fatal("Get(2, false) returned false")
}
if !bytes.Equal(data, []byte("de")) {
t.Errorf("Get(2, false) data = %q, want %q", data, "de")
}
if !bytes.Equal(b.Buf, []byte("def")) {
t.Errorf("after non-consume, Buf = %q, want %q (unchanged)", b.Buf, "def")
}
data, ok = b.Get(3, true)
if !ok {
t.Fatal("Get(3, true) returned false")
}
if !bytes.Equal(data, []byte("def")) {
t.Errorf("Get(3, true) data = %q, want %q", data, "def")
}
if b.Len() != 0 {
t.Errorf("after consume all, Len() = %d, want 0", b.Len())
}
}
func TestByteBuffer_GetString(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
s, ok := b.GetString(5, true)
if !ok {
t.Fatal("GetString(5, true) returned false")
}
if s != "hello" {
t.Errorf("GetString() = %q, want %q", s, "hello")
}
if b.Len() != 6 {
t.Errorf("after consume, Len() = %d, want 6", b.Len())
}
s, ok = b.GetString(7, false)
if ok {
t.Fatal("GetString(7, false) should return false")
}
s, ok = b.GetString(6, false)
if !ok {
t.Fatal("GetString(6, false) returned false")
}
if s != " world" {
t.Errorf("GetString() = %q, want %q", s, " world")
}
}
func TestByteBuffer_GetByte(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abc")}
bt, ok := b.GetByte(true)
if !ok {
t.Fatal("GetByte(true) returned false")
}
if bt != 'a' {
t.Errorf("GetByte() = %c, want 'a'", bt)
}
if b.Len() != 2 {
t.Errorf("after consume, Len() = %d, want 2", b.Len())
}
bt, ok = b.GetByte(false)
if !ok {
t.Fatal("GetByte(false) returned false")
}
if bt != 'b' {
t.Errorf("GetByte() = %c, want 'b'", bt)
}
if b.Len() != 2 {
t.Errorf("after non-consume, Len() = %d, want 2 (unchanged)", b.Len())
}
b.GetByte(true)
b.GetByte(true)
bt, ok = b.GetByte(true)
if ok {
t.Fatal("GetByte(true) on empty buffer should return false")
}
}
func TestByteBuffer_GetUint16(t *testing.T) {
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04}}
v, ok := b.GetUint16(false, true)
if !ok {
t.Fatal("GetUint16(bigEndian) returned false")
}
if v != 0x0102 {
t.Errorf("GetUint16(bigEndian) = 0x%04x, want 0x0102", v)
}
if b.Len() != 2 {
t.Errorf("after consume, Len() = %d, want 2", b.Len())
}
v, ok = b.GetUint16(true, true)
if !ok {
t.Fatal("GetUint16(littleEndian) returned false")
}
if v != 0x0403 {
t.Errorf("GetUint16(littleEndian) = 0x%04x, want 0x0403", v)
}
v, ok = b.GetUint16(false, false)
if ok {
t.Fatal("GetUint16 on empty buffer should return false")
}
}
func TestByteBuffer_GetUint32(t *testing.T) {
b := &ByteBuffer{Buf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}
v, ok := b.GetUint32(false, true)
if !ok {
t.Fatal("GetUint32(bigEndian) returned false")
}
if v != 0x01020304 {
t.Errorf("GetUint32(bigEndian) = 0x%08x, want 0x01020304", v)
}
if b.Len() != 4 {
t.Errorf("after consume, Len() = %d, want 4", b.Len())
}
v, ok = b.GetUint32(true, true)
if !ok {
t.Fatal("GetUint32(littleEndian) returned false")
}
if v != 0x08070605 {
t.Errorf("GetUint32(littleEndian) = 0x%08x, want 0x08070605", v)
}
v, ok = b.GetUint32(false, false)
if ok {
t.Fatal("GetUint32 on empty buffer should return false")
}
}
func TestByteBuffer_GetUntil(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello\r\nworld\r\n")}
data, ok := b.GetUntil([]byte("\r\n"), true, true)
if !ok {
t.Fatal("GetUntil(sep, include) returned false")
}
if !bytes.Equal(data, []byte("hello\r\n")) {
t.Errorf("GetUntil(include) = %q, want %q", data, "hello\\r\\n")
}
data, ok = b.GetUntil([]byte("\r\n"), false, false)
if !ok {
t.Fatal("GetUntil(sep, exclude, non-consume) returned false")
}
if !bytes.Equal(data, []byte("world")) {
t.Errorf("GetUntil(exclude) = %q, want %q", data, "world")
}
if b.Len() != 7 {
t.Errorf("after non-consume, Len() = %d, want 7", b.Len())
}
data, ok = b.GetUntil([]byte("\r\n"), true, true)
if !ok {
t.Fatal("GetUntil second (consume) returned false")
}
if !bytes.Equal(data, []byte("world\r\n")) {
t.Errorf("GetUntil second = %q, want %q", data, "world\\r\\n")
}
_, ok = b.GetUntil([]byte("xyz"), false, false)
if ok {
t.Fatal("GetUntil(not found) should return false")
}
}
func TestByteBuffer_GetSubBuffer(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello world")}
sub, ok := b.GetSubBuffer(5, true)
if !ok {
t.Fatal("GetSubBuffer() returned false")
}
if !bytes.Equal(sub.Buf, []byte("hello")) {
t.Errorf("GetSubBuffer() = %q, want %q", sub.Buf, "hello")
}
if b.Len() != 6 {
t.Errorf("after consume, Len() = %d, want 6", b.Len())
}
_, ok = b.GetSubBuffer(7, false)
if ok {
t.Fatal("GetSubBuffer(7) should return false (only 6 bytes left)")
}
}
func TestByteBuffer_Skip(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abcdef")}
ok := b.Skip(2)
if !ok {
t.Fatal("Skip(2) returned false")
}
if !bytes.Equal(b.Buf, []byte("cdef")) {
t.Errorf("after Skip(2), Buf = %q, want %q", b.Buf, "cdef")
}
ok = b.Skip(10)
if ok {
t.Fatal("Skip(10) should return false")
}
if !bytes.Equal(b.Buf, []byte("cdef")) {
t.Errorf("after failed Skip, Buf = %q, want %q (unchanged)", b.Buf, "cdef")
}
ok = b.Skip(4)
if !ok {
t.Fatal("Skip(4) returned false")
}
if b.Len() != 0 {
t.Errorf("after Skip all, Len() = %d, want 0", b.Len())
}
}
func TestByteBuffer_Reset(t *testing.T) {
b := &ByteBuffer{Buf: []byte("data")}
b.Reset()
if b.Buf != nil {
t.Errorf("after Reset, Buf = %v, want nil", b.Buf)
}
}
func TestByteBuffer_GetZeroLength(t *testing.T) {
b := &ByteBuffer{Buf: []byte("abc")}
data, ok := b.Get(0, true)
if !ok {
t.Fatal("Get(0) returned false")
}
if len(data) != 0 {
t.Errorf("Get(0) len = %d, want 0", len(data))
}
if b.Len() != 3 {
t.Errorf("after Get(0, consume), Len() = %d, want 3 (0-length consume is no-op)", b.Len())
}
}
func TestByteBuffer_GetConsumeDoesNotMutateReturnedSlice(t *testing.T) {
b := &ByteBuffer{Buf: []byte("hello")}
data, ok := b.Get(5, true)
if !ok {
t.Fatal("Get() returned false")
}
if !reflect.DeepEqual(data, []byte("hello")) {
t.Errorf("Get() returned wrong data: %v", data)
}
if b.Len() != 0 {
t.Errorf("after consume, Len() should be 0")
}
}

185
analyzer/utils/lsm_test.go Normal file
View File

@@ -0,0 +1,185 @@
package utils
import "testing"
func TestLinearStateMachine_RunPause(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction {
callCount++
return LSMActionPause
},
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if done {
t.Error("unexpected done=true")
}
if callCount != 1 {
t.Errorf("callCount = %d, want 1", callCount)
}
}
func TestLinearStateMachine_RunNext(t *testing.T) {
callOrder := []int{}
lsm := NewLinearStateMachine(
func() LSMAction { callOrder = append(callOrder, 1); return LSMActionNext },
func() LSMAction { callOrder = append(callOrder, 2); return LSMActionNext },
func() LSMAction { callOrder = append(callOrder, 3); return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false")
}
if len(callOrder) != 3 {
t.Fatalf("callOrder len = %d, want 3", len(callOrder))
}
for i, v := range []int{1, 2, 3} {
if callOrder[i] != v {
t.Errorf("callOrder[%d] = %d, want %d", i, callOrder[i], v)
}
}
}
func TestLinearStateMachine_RunReset(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction {
callCount++
if callCount == 1 {
return LSMActionReset
}
return LSMActionNext
},
func() LSMAction { callCount++; return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false")
}
if callCount != 3 {
t.Errorf("callCount = %d, want 3 (step0 reset, step0 next, step1 next)", callCount)
}
}
func TestLinearStateMachine_RunCancel(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { callCount++; return LSMActionNext },
func() LSMAction { callCount++; return LSMActionCancel },
func() LSMAction { callCount++; return LSMActionNext },
)
cancelled, done := lsm.Run()
if !cancelled {
t.Error("unexpected cancelled=false")
}
if !done {
t.Error("unexpected done=false")
}
if callCount != 2 {
t.Errorf("callCount = %d, want 2 (third step should not execute)", callCount)
}
}
func TestLinearStateMachine_RunMixed(t *testing.T) {
pauseCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { return LSMActionNext },
func() LSMAction {
pauseCount++
if pauseCount == 1 {
return LSMActionPause
}
return LSMActionNext
},
func() LSMAction { return LSMActionNext },
func() LSMAction { return LSMActionNext },
)
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if done {
t.Error("unexpected done=true on first run")
}
cancelled, done = lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true on second run")
}
if !done {
t.Error("unexpected done=false on second run")
}
}
func TestLinearStateMachine_RunEmpty(t *testing.T) {
lsm := NewLinearStateMachine()
cancelled, done := lsm.Run()
if cancelled {
t.Error("unexpected cancelled=true")
}
if !done {
t.Error("unexpected done=false for empty LSM")
}
}
func TestLinearStateMachine_AppendSteps(t *testing.T) {
lsm := NewLinearStateMachine(
func() LSMAction { return LSMActionNext },
)
lsm.Run()
lsm.AppendSteps(
func() LSMAction { return LSMActionNext },
)
_, done := lsm.Run()
if !done {
t.Error("unexpected done=false after AppendSteps")
}
}
func TestLinearStateMachine_Reset(t *testing.T) {
callCount := 0
lsm := NewLinearStateMachine(
func() LSMAction { callCount++; return LSMActionCancel },
)
lsm.Run()
if !lsm.cancelled {
t.Error("expected cancelled=true after cancel")
}
lsm.Reset()
if lsm.cancelled {
t.Error("expected cancelled=false after Reset")
}
if lsm.index != 0 {
t.Errorf("expected index=0 after Reset, got %d", lsm.index)
}
_, done := lsm.Run()
if !done {
t.Error("expected done=true, step executed again after Reset")
}
if callCount != 2 {
t.Errorf("callCount = %d, want 2 (first run + reset run)", callCount)
}
}
func TestLSMActionConstants(t *testing.T) {
if LSMActionPause != 0 {
t.Errorf("LSMActionPause = %d, want 0", LSMActionPause)
}
if LSMActionNext != 1 {
t.Errorf("LSMActionNext = %d, want 1", LSMActionNext)
}
if LSMActionReset != 2 {
t.Errorf("LSMActionReset = %d, want 2", LSMActionReset)
}
if LSMActionCancel != 3 {
t.Errorf("LSMActionCancel = %d, want 3", LSMActionCancel)
}
}

View File

@@ -0,0 +1,29 @@
package utils
import (
"reflect"
"testing"
)
func TestByteSlicesToStrings(t *testing.T) {
tests := []struct {
name string
input [][]byte
want []string
}{
{"nil", nil, []string{}},
{"empty", [][]byte{}, []string{}},
{"single", [][]byte{[]byte("hello")}, []string{"hello"}},
{"multiple", [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, []string{"foo", "bar", "baz"}},
{"empty element", [][]byte{[]byte("a"), []byte{}, []byte("b")}, []string{"a", "", "b"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ByteSlicesToStrings(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ByteSlicesToStrings() = %v, want %v", got, tt.want)
}
})
}
}