206 lines
4.9 KiB
Go
206 lines
4.9 KiB
Go
package udp
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/modifier"
|
|
|
|
"github.com/google/gopacket"
|
|
"github.com/google/gopacket/layers"
|
|
)
|
|
|
|
func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte {
|
|
t.Helper()
|
|
dns := &layers.DNS{
|
|
ID: id,
|
|
QR: true,
|
|
OpCode: layers.DNSOpCodeQuery,
|
|
AA: false,
|
|
RD: true,
|
|
RA: true,
|
|
ResponseCode: layers.DNSResponseCodeNoErr,
|
|
QDCount: 1,
|
|
Questions: []layers.DNSQuestion{
|
|
{
|
|
Name: []byte(name),
|
|
Type: qtype,
|
|
Class: layers.DNSClassIN,
|
|
},
|
|
},
|
|
}
|
|
buf := gopacket.NewSerializeBuffer()
|
|
err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
|
|
FixLengths: true,
|
|
ComputeChecksums: true,
|
|
}, dns)
|
|
if err != nil {
|
|
t.Fatalf("failed to serialize DNS response: %v", err)
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
func TestDNSModifier_Name(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
if m.Name() != "dns" {
|
|
t.Errorf("Name() = %q, want dns", m.Name())
|
|
}
|
|
}
|
|
|
|
func TestDNSModifier_New_NoArgs(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
inst, err := m.New(map[string]interface{}{})
|
|
if err != nil {
|
|
t.Fatalf("New() unexpected error: %v", err)
|
|
}
|
|
if inst == nil {
|
|
t.Fatal("New() returned nil instance")
|
|
}
|
|
}
|
|
|
|
func TestDNSModifier_New_WithA(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
inst, err := m.New(map[string]interface{}{
|
|
"a": "192.168.1.1",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New() unexpected error: %v", err)
|
|
}
|
|
dnsInst, ok := inst.(*dnsModifierInstance)
|
|
if !ok {
|
|
t.Fatalf("New() returned wrong type: %T", inst)
|
|
}
|
|
if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" {
|
|
t.Errorf("A = %v, want 192.168.1.1", dnsInst.A)
|
|
}
|
|
}
|
|
|
|
func TestDNSModifier_New_WithAAAA(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
inst, err := m.New(map[string]interface{}{
|
|
"aaaa": "2001:db8::1",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New() unexpected error: %v", err)
|
|
}
|
|
dnsInst, ok := inst.(*dnsModifierInstance)
|
|
if !ok {
|
|
t.Fatalf("New() returned wrong type: %T", inst)
|
|
}
|
|
if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" {
|
|
t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA)
|
|
}
|
|
}
|
|
|
|
func TestDNSModifier_New_InvalidA(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
_, err := m.New(map[string]interface{}{
|
|
"a": "not-an-ip",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("New() expected error for invalid A")
|
|
}
|
|
}
|
|
|
|
func TestDNSModifier_New_InvalidAAAA(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
_, err := m.New(map[string]interface{}{
|
|
"aaaa": "not-an-ip",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("New() expected error for invalid AAAA")
|
|
}
|
|
}
|
|
|
|
func TestDNSModifierInstance_Process_AType(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
inst, err := m.New(map[string]interface{}{
|
|
"a": "10.10.10.10",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com")
|
|
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
|
if err != nil {
|
|
t.Fatalf("Process() error: %v", err)
|
|
}
|
|
|
|
result := &layers.DNS{}
|
|
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
|
if err != nil {
|
|
t.Fatalf("decode result: %v", err)
|
|
}
|
|
if len(result.Answers) != 1 {
|
|
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
|
}
|
|
if result.Answers[0].Type != layers.DNSTypeA {
|
|
t.Errorf("answer type = %d, want A", result.Answers[0].Type)
|
|
}
|
|
if result.Answers[0].IP.String() != "10.10.10.10" {
|
|
t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP)
|
|
}
|
|
}
|
|
|
|
func TestDNSModifierInstance_Process_AAAAType(t *testing.T) {
|
|
m := &DNSModifier{}
|
|
inst, err := m.New(map[string]interface{}{
|
|
"aaaa": "2001:db8::1",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com")
|
|
modified, err := inst.(*dnsModifierInstance).Process(dnsResp)
|
|
if err != nil {
|
|
t.Fatalf("Process() error: %v", err)
|
|
}
|
|
|
|
result := &layers.DNS{}
|
|
err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback)
|
|
if err != nil {
|
|
t.Fatalf("decode result: %v", err)
|
|
}
|
|
if len(result.Answers) != 1 {
|
|
t.Fatalf("expected 1 answer, got %d", len(result.Answers))
|
|
}
|
|
if result.Answers[0].Type != layers.DNSTypeAAAA {
|
|
t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type)
|
|
}
|
|
}
|
|
|
|
func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) {
|
|
inst := &dnsModifierInstance{}
|
|
query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com")
|
|
query[2] &^= 0x80 // clear QR bit
|
|
|
|
_, err := inst.Process(query)
|
|
if err == nil {
|
|
t.Fatal("expected error for non-response")
|
|
}
|
|
errPkt, ok := err.(*modifier.ErrInvalidPacket)
|
|
if !ok {
|
|
t.Fatalf("expected *ErrInvalidPacket, got %T", err)
|
|
}
|
|
_ = errPkt
|
|
}
|
|
|
|
func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) {
|
|
inst := &dnsModifierInstance{}
|
|
dns := &layers.DNS{
|
|
ID: 1,
|
|
QR: true,
|
|
RD: true,
|
|
RA: true,
|
|
ResponseCode: layers.DNSResponseCodeNoErr,
|
|
}
|
|
buf := gopacket.NewSerializeBuffer()
|
|
gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns)
|
|
|
|
_, err := inst.Process(buf.Bytes())
|
|
if err == nil {
|
|
t.Fatal("expected error for empty questions")
|
|
}
|
|
}
|