From 1e5c1dea75a21cc37397f57f913e7a9b262a82da Mon Sep 17 00:00:00 2001 From: hayzam Date: Thu, 14 May 2026 09:41:07 +0530 Subject: [PATCH] flows: implement ipv --- engine/packet.go | 127 ++++++++++++++++++++++++++++++------ engine/packet_ipv6_test.go | 127 ++++++++++++++++++++++++++++++++++++ engine/tcp_flow.go | 8 +-- engine/worker.go | 128 ++++++++++++++++++++++++------------- engine/worker_ipv6_test.go | 114 +++++++++++++++++++++++++++++++++ 5 files changed, 435 insertions(+), 69 deletions(-) create mode 100644 engine/packet_ipv6_test.go create mode 100644 engine/worker_ipv6_test.go diff --git a/engine/packet.go b/engine/packet.go index 8d5354e..6b47e4e 100644 --- a/engine/packet.go +++ b/engine/packet.go @@ -8,11 +8,24 @@ type L3Info struct { IHL uint8 SrcIP [4]byte DstIP [4]byte + SrcIPv6 [16]byte + DstIPv6 [16]byte Length uint16 } -func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) } -func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) } +func (i L3Info) SrcIPAddr() net.IP { + if i.Version == 6 { + return net.IP(i.SrcIPv6[:]) + } + return net.IP(i.SrcIP[:]) +} + +func (i L3Info) DstIPAddr() net.IP { + if i.Version == 6 { + return net.IP(i.DstIPv6[:]) + } + return net.IP(i.DstIP[:]) +} type TCPInfo struct { SrcPort uint16 @@ -32,29 +45,105 @@ type UDPInfo struct { } func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) { - if len(data) < 20 { + if len(data) < 1 { return } version := data[0] >> 4 - if version != 4 { + switch version { + case 4: + if len(data) < 20 { + return + } + ihl := data[0] & 0x0F + if ihl < 5 || len(data) < int(ihl)*4 { + return + } + totalLen := int(uint16(data[2])<<8 | uint16(data[3])) + if totalLen < int(ihl)*4 || totalLen > len(data) { + totalLen = len(data) + } + return L3Info{ + Version: 4, + Protocol: data[9], + IHL: ihl, + Length: uint16(totalLen), + SrcIP: [4]byte{data[12], data[13], data[14], data[15]}, + DstIP: [4]byte{data[16], data[17], data[18], data[19]}, + }, data[ihl*4 : totalLen], true + case 6: + if len(data) < 40 { + return + } + payloadLen := int(uint16(data[4])<<8 | uint16(data[5])) + totalLen := 40 + payloadLen + if payloadLen == 0 || totalLen > len(data) { + totalLen = len(data) + } + protocol, tr, ipv6OK := parseIPv6Transport(data, data[6], totalLen) + if !ipv6OK { + return + } + var srcIP, dstIP [16]byte + copy(srcIP[:], data[8:24]) + copy(dstIP[:], data[24:40]) + return L3Info{ + Version: 6, + Protocol: protocol, + SrcIPv6: srcIP, + DstIPv6: dstIP, + Length: uint16(totalLen), + }, tr, true + default: return } - ihl := data[0] & 0x0F - if ihl < 5 || len(data) < int(ihl)*4 { - return +} + +func parseIPv6Transport(data []byte, nextHeader uint8, totalLen int) (protocol uint8, transport []byte, ok bool) { + offset := 40 + proto := nextHeader + for { + if offset > totalLen { + return 0, nil, false + } + switch proto { + case 0, 43, 60: // Hop-by-hop options, Routing, Destination options + if offset+2 > totalLen { + return 0, nil, false + } + hdrLen := (int(data[offset+1]) + 1) * 8 + if hdrLen < 8 || offset+hdrLen > totalLen { + return 0, nil, false + } + proto = data[offset] + offset += hdrLen + case 44: // Fragment + if offset+8 > totalLen { + return 0, nil, false + } + // Only first fragment carries L4 headers. + fragOffset := (uint16(data[offset+2])<<8 | uint16(data[offset+3])) >> 3 + if fragOffset != 0 { + return 0, nil, false + } + proto = data[offset] + offset += 8 + case 51: // Authentication Header + if offset+2 > totalLen { + return 0, nil, false + } + hdrLen := (int(data[offset+1]) + 2) * 4 + if hdrLen < 8 || offset+hdrLen > totalLen { + return 0, nil, false + } + proto = data[offset] + offset += hdrLen + default: + if offset > totalLen { + return 0, nil, false + } + return proto, data[offset:totalLen], true + } } - totalLen := int(uint16(data[2])<<8 | uint16(data[3])) - if totalLen < int(ihl)*4 || totalLen > len(data) { - totalLen = len(data) - } - return L3Info{ - Version: 4, - Protocol: data[9], - IHL: ihl, - Length: uint16(totalLen), - SrcIP: [4]byte{data[12], data[13], data[14], data[15]}, - DstIP: [4]byte{data[16], data[17], data[18], data[19]}, - }, data[ihl*4:totalLen], true } func ParseTCP(transport []byte) (TCPInfo, []byte, bool) { diff --git a/engine/packet_ipv6_test.go b/engine/packet_ipv6_test.go new file mode 100644 index 0000000..98d5712 --- /dev/null +++ b/engine/packet_ipv6_test.go @@ -0,0 +1,127 @@ +package engine + +import ( + "encoding/binary" + "net" + "testing" + + "github.com/google/gopacket/layers" +) + +func TestParseL3IPv6UDP(t *testing.T) { + src := net.ParseIP("2001:db8::10").To16() + dst := net.ParseIP("2001:db8::20").To16() + payload := []byte("hello") + + pkt := buildIPv6UDPPacket(t, src, dst, 12345, 443, payload) + l3, transport, ok := ParseL3(pkt) + if !ok { + t.Fatal("ParseL3 should parse IPv6 packet") + } + if l3.Version != 6 { + t.Fatalf("version=%d want=6", l3.Version) + } + if l3.Protocol != 17 { + t.Fatalf("protocol=%d want=17 (udp)", l3.Protocol) + } + if !l3.SrcIPAddr().Equal(src) { + t.Fatalf("src=%v want=%v", l3.SrcIPAddr(), src) + } + if !l3.DstIPAddr().Equal(dst) { + t.Fatalf("dst=%v want=%v", l3.DstIPAddr(), dst) + } + udp, gotPayload, ok := ParseUDP(transport) + if !ok { + t.Fatal("ParseUDP should parse transport payload") + } + if udp.SrcPort != 12345 || udp.DstPort != 443 { + t.Fatalf("ports=%d->%d want=12345->443", udp.SrcPort, udp.DstPort) + } + if string(gotPayload) != string(payload) { + t.Fatalf("payload=%q want=%q", string(gotPayload), string(payload)) + } +} + +func TestParseL3IPv6HopByHopThenUDP(t *testing.T) { + src := net.ParseIP("2001:db8::1").To16() + dst := net.ParseIP("2001:db8::2").To16() + udpPayload := []byte("abc") + udpLen := 8 + len(udpPayload) + totalPayloadLen := 8 + udpLen // 8-byte hop-by-hop extension + udp packet + + pkt := make([]byte, 40+totalPayloadLen) + pkt[0] = 0x60 + binary.BigEndian.PutUint16(pkt[4:6], uint16(totalPayloadLen)) + pkt[6] = 0 // Hop-by-hop + pkt[7] = 64 // Hop limit + copy(pkt[8:24], src) + copy(pkt[24:40], dst) + + off := 40 + pkt[off+0] = 17 // Next header: UDP + pkt[off+1] = 0 // Hdr Ext Len: (0+1)*8 = 8 bytes + off += 8 + + binary.BigEndian.PutUint16(pkt[off:off+2], 5353) + binary.BigEndian.PutUint16(pkt[off+2:off+4], 53) + binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen)) + copy(pkt[off+8:], udpPayload) + + l3, transport, ok := ParseL3(pkt) + if !ok { + t.Fatal("ParseL3 should parse IPv6 packet with extension headers") + } + if l3.Version != 6 || l3.Protocol != 17 { + t.Fatalf("version/protocol=%d/%d want=6/17", l3.Version, l3.Protocol) + } + udp, gotPayload, ok := ParseUDP(transport) + if !ok { + t.Fatal("ParseUDP should parse UDP after hop-by-hop extension") + } + if udp.SrcPort != 5353 || udp.DstPort != 53 { + t.Fatalf("ports=%d->%d want=5353->53", udp.SrcPort, udp.DstPort) + } + if string(gotPayload) != string(udpPayload) { + t.Fatalf("payload=%q want=%q", string(gotPayload), string(udpPayload)) + } +} + +func buildIPv6UDPPacket(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte { + t.Helper() + udpLen := 8 + len(payload) + pkt := make([]byte, 40+udpLen) + pkt[0] = 0x60 + binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen)) + pkt[6] = 17 + pkt[7] = 64 + copy(pkt[8:24], src.To16()) + copy(pkt[24:40], dst.To16()) + off := 40 + binary.BigEndian.PutUint16(pkt[off:off+2], srcPort) + binary.BigEndian.PutUint16(pkt[off+2:off+4], dstPort) + binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen)) + copy(pkt[off+8:], payload) + return pkt +} + +func TestParseL3IPv6FragmentNonFirst(t *testing.T) { + src := net.ParseIP("2001:db8::a").To16() + dst := net.ParseIP("2001:db8::b").To16() + + // IPv6 header + fragment header (offset != 0) + pkt := make([]byte, 48) + pkt[0] = 0x60 + binary.BigEndian.PutUint16(pkt[4:6], 8) + pkt[6] = 44 + pkt[7] = 64 + copy(pkt[8:24], src) + copy(pkt[24:40], dst) + pkt[40] = uint8(layers.IPProtocolUDP) + // fragment offset in 8-byte units: 1 (non-first fragment) + binary.BigEndian.PutUint16(pkt[42:44], 1<<3) + + _, _, ok := ParseL3(pkt) + if ok { + t.Fatal("ParseL3 should reject non-first IPv6 fragments for L4 parsing") + } +} diff --git a/engine/tcp_flow.go b/engine/tcp_flow.go index 33ba874..17d8991 100644 --- a/engine/tcp_flow.go +++ b/engine/tcp_flow.go @@ -22,8 +22,6 @@ const ( type tcpFlow struct { streamID uint32 - srcIP [4]byte - dstIP [4]byte srcPort uint16 dstPort uint16 @@ -199,8 +197,8 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { id := m.sfNode.Generate() - ipSrc := net.IP(l3.SrcIP[:]) - ipDst := net.IP(l3.DstIP[:]) + ipSrc := l3.SrcIPAddr() + ipDst := l3.DstIPAddr() if len(srcMAC) == 0 && m.macResolver != nil { srcMAC = m.macResolver.Resolve(ipSrc) } @@ -244,8 +242,6 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay flow := &tcpFlow{ streamID: streamID, - srcIP: l3.SrcIP, - dstIP: l3.DstIP, srcPort: tcp.SrcPort, dstPort: tcp.DstPort, info: info, diff --git a/engine/worker.go b/engine/worker.go index 684acbe..080ceca 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -150,56 +150,67 @@ func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) { return io.VerdictAccept, nil } - ipVersion := data[0] >> 4 - if ipVersion == 4 { - l3, transport, ok := ParseL3(data) - if !ok { - return io.VerdictAccept, nil - } - switch l3.Protocol { - case 6: // TCP - tcp, payload, ok := ParseTCP(transport) - if !ok { - return io.VerdictAccept, nil - } - verdict := w.tcpFlowMgr.handle( - wp.StreamID, l3, tcp, payload, - wp.SrcMAC, wp.DstMAC, - ) - return verdict, nil - - case 17: // UDP - udp, payload, ok := ParseUDP(transport) - if !ok { - return io.VerdictAccept, nil - } - v, modPayload := w.handleUDP( - wp.StreamID, l3, udp, payload, - wp.SrcMAC, wp.DstMAC, - ) - if v == io.VerdictAcceptModify && modPayload != nil { - return w.serializeModifiedUDP(data, l3, udp, transport, modPayload) - } - return v, nil - - default: - return io.VerdictAccept, nil - } + if v, b, ok := w.handleIPPacket(wp, data); ok { + return v, b } - // Ethernet frame path (for custom PacketIO) - if ipVersion == 6 { - // TODO: IPv6 support with raw parsing - return io.VerdictAccept, nil + // Ethernet frame fallback path (for custom PacketIO implementations). + if l3Payload, ok := extractL3PayloadFromEthernet(data); ok { + if v, b, ok := w.handleIPPacket(wp, l3Payload); ok { + return v, b + } } return io.VerdictAccept, nil } +func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) { + l3, transport, ok := ParseL3(data) + if !ok { + return io.VerdictAccept, nil, false + } + switch l3.Protocol { + case 6: // TCP + tcp, payload, ok := ParseTCP(transport) + if !ok { + return io.VerdictAccept, nil, true + } + verdict := w.tcpFlowMgr.handle( + wp.StreamID, l3, tcp, payload, + wp.SrcMAC, wp.DstMAC, + ) + return verdict, nil, true + case 17: // UDP + udp, payload, ok := ParseUDP(transport) + if !ok { + return io.VerdictAccept, nil, true + } + v, modPayload := w.handleUDP( + wp.StreamID, l3, udp, payload, + wp.SrcMAC, wp.DstMAC, + ) + if v == io.VerdictAcceptModify && modPayload != nil { + mv, mb := w.serializeModifiedUDP(data, l3, modPayload) + return mv, mb, true + } + return v, nil, true + default: + return io.VerdictAccept, nil, true + } +} + func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { - ipSrc := net.IP(l3.SrcIP[:]) - ipDst := net.IP(l3.DstIP[:]) - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4()) + ipSrc := l3.SrcIPAddr() + ipDst := l3.DstIPAddr() + endpointType := layers.EndpointIPv4 + flowSrc := ipSrc.To4() + flowDst := ipDst.To4() + if l3.Version == 6 { + endpointType = layers.EndpointIPv6 + flowSrc = ipSrc.To16() + flowDst = ipDst.To16() + } + ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst) if len(srcMAC) == 0 && w.macResolver != nil { srcMAC = w.macResolver.Resolve(ipSrc) @@ -219,8 +230,12 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by return io.Verdict(uc.Verdict), uc.Packet } -func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) { - ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) +func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) { + layerType := layers.LayerTypeIPv4 + if l3.Version == 6 { + layerType = layers.LayerTypeIPv6 + } + ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) netLayer := ipPkt.NetworkLayer() trLayer := ipPkt.TransportLayer() if netLayer == nil || trLayer == nil { @@ -240,3 +255,28 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, t } return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes() } + +func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) { + if len(data) < 14 { + return nil, false + } + offset := 12 + etherType := uint16(data[offset])<<8 | uint16(data[offset+1]) + offset += 2 + + for etherType == 0x8100 || etherType == 0x88A8 { + if len(data) < offset+4 { + return nil, false + } + etherType = uint16(data[offset+2])<<8 | uint16(data[offset+3]) + offset += 4 + } + + if etherType != 0x0800 && etherType != 0x86DD { + return nil, false + } + if len(data) <= offset { + return nil, false + } + return data[offset:], true +} diff --git a/engine/worker_ipv6_test.go b/engine/worker_ipv6_test.go new file mode 100644 index 0000000..8c55fc8 --- /dev/null +++ b/engine/worker_ipv6_test.go @@ -0,0 +1,114 @@ +package engine + +import ( + "net" + "testing" + + "git.difuse.io/Difuse/Mellaris/io" + "git.difuse.io/Difuse/Mellaris/ruleset" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +func TestWorkerHandleIPv6TCP(t *testing.T) { + w, err := newWorker(workerConfig{ + ID: 0, + Logger: noopTestLogger{}, + Ruleset: fixedRuleset{action: ruleset.ActionBlock}, + ResultChan: make(chan workerResult, 1), + }) + if err != nil { + t.Fatalf("new worker: %v", err) + } + + src := net.ParseIP("2001:db8::11").To16() + dst := net.ParseIP("2001:db8::22").To16() + data := serializeIPv6TCP(t, src, dst, 42310, 443, 1000) + + v, _ := w.handle(&workerPacket{ + StreamID: 11, + Data: data, + }) + if v != io.VerdictDropStream { + t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream) + } +} + +func TestWorkerHandleIPv6UDP(t *testing.T) { + w, err := newWorker(workerConfig{ + ID: 0, + Logger: noopTestLogger{}, + Ruleset: fixedRuleset{action: ruleset.ActionBlock}, + ResultChan: make(chan workerResult, 1), + }) + if err != nil { + t.Fatalf("new worker: %v", err) + } + + src := net.ParseIP("2001:db8::33").To16() + dst := net.ParseIP("2001:db8::44").To16() + data := serializeIPv6UDP(t, src, dst, 50000, 53, []byte("dns")) + + v, _ := w.handle(&workerPacket{ + StreamID: 12, + Data: data, + }) + if v != io.VerdictDropStream { + t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream) + } +} + +func serializeIPv6TCP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, seq uint32) []byte { + t.Helper() + ip6 := &layers.IPv6{ + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolTCP, + SrcIP: src, + DstIP: dst, + } + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Seq: seq, + SYN: true, + } + if err := tcp.SetNetworkLayerForChecksum(ip6); err != nil { + t.Fatalf("set tcp checksum network layer: %v", err) + } + buf := gopacket.NewSerializeBuffer() + if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, ip6, tcp); err != nil { + t.Fatalf("serialize ipv6 tcp: %v", err) + } + return append([]byte(nil), buf.Bytes()...) +} + +func serializeIPv6UDP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte { + t.Helper() + ip6 := &layers.IPv6{ + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + SrcIP: src, + DstIP: dst, + } + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + if err := udp.SetNetworkLayerForChecksum(ip6); err != nil { + t.Fatalf("set udp checksum network layer: %v", err) + } + buf := gopacket.NewSerializeBuffer() + if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, ip6, udp, gopacket.Payload(payload)); err != nil { + t.Fatalf("serialize ipv6 udp: %v", err) + } + return append([]byte(nil), buf.Bytes()...) +}