package udp import ( "testing" "git.difuse.io/Difuse/Mellaris/modifier" "github.com/google/gopacket" "github.com/google/gopacket/layers" ) func dnsResponseBytes(t *testing.T, id uint16, qtype layers.DNSType, name string) []byte { t.Helper() dns := &layers.DNS{ ID: id, QR: true, OpCode: layers.DNSOpCodeQuery, AA: false, RD: true, RA: true, ResponseCode: layers.DNSResponseCodeNoErr, QDCount: 1, Questions: []layers.DNSQuestion{ { Name: []byte(name), Type: qtype, Class: layers.DNSClassIN, }, }, } buf := gopacket.NewSerializeBuffer() err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ FixLengths: true, ComputeChecksums: true, }, dns) if err != nil { t.Fatalf("failed to serialize DNS response: %v", err) } return buf.Bytes() } func TestDNSModifier_Name(t *testing.T) { m := &DNSModifier{} if m.Name() != "dns" { t.Errorf("Name() = %q, want dns", m.Name()) } } func TestDNSModifier_New_NoArgs(t *testing.T) { m := &DNSModifier{} inst, err := m.New(map[string]interface{}{}) if err != nil { t.Fatalf("New() unexpected error: %v", err) } if inst == nil { t.Fatal("New() returned nil instance") } } func TestDNSModifier_New_WithA(t *testing.T) { m := &DNSModifier{} inst, err := m.New(map[string]interface{}{ "a": "192.168.1.1", }) if err != nil { t.Fatalf("New() unexpected error: %v", err) } dnsInst, ok := inst.(*dnsModifierInstance) if !ok { t.Fatalf("New() returned wrong type: %T", inst) } if dnsInst.A == nil || dnsInst.A.String() != "192.168.1.1" { t.Errorf("A = %v, want 192.168.1.1", dnsInst.A) } } func TestDNSModifier_New_WithAAAA(t *testing.T) { m := &DNSModifier{} inst, err := m.New(map[string]interface{}{ "aaaa": "2001:db8::1", }) if err != nil { t.Fatalf("New() unexpected error: %v", err) } dnsInst, ok := inst.(*dnsModifierInstance) if !ok { t.Fatalf("New() returned wrong type: %T", inst) } if dnsInst.AAAA == nil || dnsInst.AAAA.String() != "2001:db8::1" { t.Errorf("AAAA = %v, want 2001:db8::1", dnsInst.AAAA) } } func TestDNSModifier_New_InvalidA(t *testing.T) { m := &DNSModifier{} _, err := m.New(map[string]interface{}{ "a": "not-an-ip", }) if err == nil { t.Fatal("New() expected error for invalid A") } } func TestDNSModifier_New_InvalidAAAA(t *testing.T) { m := &DNSModifier{} _, err := m.New(map[string]interface{}{ "aaaa": "not-an-ip", }) if err == nil { t.Fatal("New() expected error for invalid AAAA") } } func TestDNSModifierInstance_Process_AType(t *testing.T) { m := &DNSModifier{} inst, err := m.New(map[string]interface{}{ "a": "10.10.10.10", }) if err != nil { t.Fatal(err) } dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeA, "example.com") modified, err := inst.(*dnsModifierInstance).Process(dnsResp) if err != nil { t.Fatalf("Process() error: %v", err) } result := &layers.DNS{} err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback) if err != nil { t.Fatalf("decode result: %v", err) } if len(result.Answers) != 1 { t.Fatalf("expected 1 answer, got %d", len(result.Answers)) } if result.Answers[0].Type != layers.DNSTypeA { t.Errorf("answer type = %d, want A", result.Answers[0].Type) } if result.Answers[0].IP.String() != "10.10.10.10" { t.Errorf("answer IP = %s, want 10.10.10.10", result.Answers[0].IP) } } func TestDNSModifierInstance_Process_AAAAType(t *testing.T) { m := &DNSModifier{} inst, err := m.New(map[string]interface{}{ "aaaa": "2001:db8::1", }) if err != nil { t.Fatal(err) } dnsResp := dnsResponseBytes(t, 1, layers.DNSTypeAAAA, "example.com") modified, err := inst.(*dnsModifierInstance).Process(dnsResp) if err != nil { t.Fatalf("Process() error: %v", err) } result := &layers.DNS{} err = result.DecodeFromBytes(modified, gopacket.NilDecodeFeedback) if err != nil { t.Fatalf("decode result: %v", err) } if len(result.Answers) != 1 { t.Fatalf("expected 1 answer, got %d", len(result.Answers)) } if result.Answers[0].Type != layers.DNSTypeAAAA { t.Errorf("answer type = %d, want AAAA", result.Answers[0].Type) } } func TestDNSModifierInstance_Process_NotAResponse(t *testing.T) { inst := &dnsModifierInstance{} query := dnsResponseBytes(t, 1, layers.DNSTypeA, "test.com") query[2] &^= 0x80 // clear QR bit _, err := inst.Process(query) if err == nil { t.Fatal("expected error for non-response") } errPkt, ok := err.(*modifier.ErrInvalidPacket) if !ok { t.Fatalf("expected *ErrInvalidPacket, got %T", err) } _ = errPkt } func TestDNSModifierInstance_Process_NoQuestions(t *testing.T) { inst := &dnsModifierInstance{} dns := &layers.DNS{ ID: 1, QR: true, RD: true, RA: true, ResponseCode: layers.DNSResponseCodeNoErr, } buf := gopacket.NewSerializeBuffer() gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, dns) _, err := inst.Process(buf.Bytes()) if err == nil { t.Fatal("expected error for empty questions") } }