This commit is contained in:
2026-05-14 04:12:57 +00:00
5 changed files with 435 additions and 69 deletions
+108 -19
View File
@@ -8,11 +8,24 @@ type L3Info struct {
IHL uint8
SrcIP [4]byte
DstIP [4]byte
SrcIPv6 [16]byte
DstIPv6 [16]byte
Length uint16
}
func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) }
func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) }
func (i L3Info) SrcIPAddr() net.IP {
if i.Version == 6 {
return net.IP(i.SrcIPv6[:])
}
return net.IP(i.SrcIP[:])
}
func (i L3Info) DstIPAddr() net.IP {
if i.Version == 6 {
return net.IP(i.DstIPv6[:])
}
return net.IP(i.DstIP[:])
}
type TCPInfo struct {
SrcPort uint16
@@ -32,29 +45,105 @@ type UDPInfo struct {
}
func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
if len(data) < 20 {
if len(data) < 1 {
return
}
version := data[0] >> 4
if version != 4 {
switch version {
case 4:
if len(data) < 20 {
return
}
ihl := data[0] & 0x0F
if ihl < 5 || len(data) < int(ihl)*4 {
return
}
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 || totalLen > len(data) {
totalLen = len(data)
}
return L3Info{
Version: 4,
Protocol: data[9],
IHL: ihl,
Length: uint16(totalLen),
SrcIP: [4]byte{data[12], data[13], data[14], data[15]},
DstIP: [4]byte{data[16], data[17], data[18], data[19]},
}, data[ihl*4 : totalLen], true
case 6:
if len(data) < 40 {
return
}
payloadLen := int(uint16(data[4])<<8 | uint16(data[5]))
totalLen := 40 + payloadLen
if payloadLen == 0 || totalLen > len(data) {
totalLen = len(data)
}
protocol, tr, ipv6OK := parseIPv6Transport(data, data[6], totalLen)
if !ipv6OK {
return
}
var srcIP, dstIP [16]byte
copy(srcIP[:], data[8:24])
copy(dstIP[:], data[24:40])
return L3Info{
Version: 6,
Protocol: protocol,
SrcIPv6: srcIP,
DstIPv6: dstIP,
Length: uint16(totalLen),
}, tr, true
default:
return
}
ihl := data[0] & 0x0F
if ihl < 5 || len(data) < int(ihl)*4 {
return
}
func parseIPv6Transport(data []byte, nextHeader uint8, totalLen int) (protocol uint8, transport []byte, ok bool) {
offset := 40
proto := nextHeader
for {
if offset > totalLen {
return 0, nil, false
}
switch proto {
case 0, 43, 60: // Hop-by-hop options, Routing, Destination options
if offset+2 > totalLen {
return 0, nil, false
}
hdrLen := (int(data[offset+1]) + 1) * 8
if hdrLen < 8 || offset+hdrLen > totalLen {
return 0, nil, false
}
proto = data[offset]
offset += hdrLen
case 44: // Fragment
if offset+8 > totalLen {
return 0, nil, false
}
// Only first fragment carries L4 headers.
fragOffset := (uint16(data[offset+2])<<8 | uint16(data[offset+3])) >> 3
if fragOffset != 0 {
return 0, nil, false
}
proto = data[offset]
offset += 8
case 51: // Authentication Header
if offset+2 > totalLen {
return 0, nil, false
}
hdrLen := (int(data[offset+1]) + 2) * 4
if hdrLen < 8 || offset+hdrLen > totalLen {
return 0, nil, false
}
proto = data[offset]
offset += hdrLen
default:
if offset > totalLen {
return 0, nil, false
}
return proto, data[offset:totalLen], true
}
}
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 || totalLen > len(data) {
totalLen = len(data)
}
return L3Info{
Version: 4,
Protocol: data[9],
IHL: ihl,
Length: uint16(totalLen),
SrcIP: [4]byte{data[12], data[13], data[14], data[15]},
DstIP: [4]byte{data[16], data[17], data[18], data[19]},
}, data[ihl*4:totalLen], true
}
func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
+127
View File
@@ -0,0 +1,127 @@
package engine
import (
"encoding/binary"
"net"
"testing"
"github.com/google/gopacket/layers"
)
func TestParseL3IPv6UDP(t *testing.T) {
src := net.ParseIP("2001:db8::10").To16()
dst := net.ParseIP("2001:db8::20").To16()
payload := []byte("hello")
pkt := buildIPv6UDPPacket(t, src, dst, 12345, 443, payload)
l3, transport, ok := ParseL3(pkt)
if !ok {
t.Fatal("ParseL3 should parse IPv6 packet")
}
if l3.Version != 6 {
t.Fatalf("version=%d want=6", l3.Version)
}
if l3.Protocol != 17 {
t.Fatalf("protocol=%d want=17 (udp)", l3.Protocol)
}
if !l3.SrcIPAddr().Equal(src) {
t.Fatalf("src=%v want=%v", l3.SrcIPAddr(), src)
}
if !l3.DstIPAddr().Equal(dst) {
t.Fatalf("dst=%v want=%v", l3.DstIPAddr(), dst)
}
udp, gotPayload, ok := ParseUDP(transport)
if !ok {
t.Fatal("ParseUDP should parse transport payload")
}
if udp.SrcPort != 12345 || udp.DstPort != 443 {
t.Fatalf("ports=%d->%d want=12345->443", udp.SrcPort, udp.DstPort)
}
if string(gotPayload) != string(payload) {
t.Fatalf("payload=%q want=%q", string(gotPayload), string(payload))
}
}
func TestParseL3IPv6HopByHopThenUDP(t *testing.T) {
src := net.ParseIP("2001:db8::1").To16()
dst := net.ParseIP("2001:db8::2").To16()
udpPayload := []byte("abc")
udpLen := 8 + len(udpPayload)
totalPayloadLen := 8 + udpLen // 8-byte hop-by-hop extension + udp packet
pkt := make([]byte, 40+totalPayloadLen)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(totalPayloadLen))
pkt[6] = 0 // Hop-by-hop
pkt[7] = 64 // Hop limit
copy(pkt[8:24], src)
copy(pkt[24:40], dst)
off := 40
pkt[off+0] = 17 // Next header: UDP
pkt[off+1] = 0 // Hdr Ext Len: (0+1)*8 = 8 bytes
off += 8
binary.BigEndian.PutUint16(pkt[off:off+2], 5353)
binary.BigEndian.PutUint16(pkt[off+2:off+4], 53)
binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen))
copy(pkt[off+8:], udpPayload)
l3, transport, ok := ParseL3(pkt)
if !ok {
t.Fatal("ParseL3 should parse IPv6 packet with extension headers")
}
if l3.Version != 6 || l3.Protocol != 17 {
t.Fatalf("version/protocol=%d/%d want=6/17", l3.Version, l3.Protocol)
}
udp, gotPayload, ok := ParseUDP(transport)
if !ok {
t.Fatal("ParseUDP should parse UDP after hop-by-hop extension")
}
if udp.SrcPort != 5353 || udp.DstPort != 53 {
t.Fatalf("ports=%d->%d want=5353->53", udp.SrcPort, udp.DstPort)
}
if string(gotPayload) != string(udpPayload) {
t.Fatalf("payload=%q want=%q", string(gotPayload), string(udpPayload))
}
}
func buildIPv6UDPPacket(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte {
t.Helper()
udpLen := 8 + len(payload)
pkt := make([]byte, 40+udpLen)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen))
pkt[6] = 17
pkt[7] = 64
copy(pkt[8:24], src.To16())
copy(pkt[24:40], dst.To16())
off := 40
binary.BigEndian.PutUint16(pkt[off:off+2], srcPort)
binary.BigEndian.PutUint16(pkt[off+2:off+4], dstPort)
binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen))
copy(pkt[off+8:], payload)
return pkt
}
func TestParseL3IPv6FragmentNonFirst(t *testing.T) {
src := net.ParseIP("2001:db8::a").To16()
dst := net.ParseIP("2001:db8::b").To16()
// IPv6 header + fragment header (offset != 0)
pkt := make([]byte, 48)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], 8)
pkt[6] = 44
pkt[7] = 64
copy(pkt[8:24], src)
copy(pkt[24:40], dst)
pkt[40] = uint8(layers.IPProtocolUDP)
// fragment offset in 8-byte units: 1 (non-first fragment)
binary.BigEndian.PutUint16(pkt[42:44], 1<<3)
_, _, ok := ParseL3(pkt)
if ok {
t.Fatal("ParseL3 should reject non-first IPv6 fragments for L4 parsing")
}
}
+2 -6
View File
@@ -22,8 +22,6 @@ const (
type tcpFlow struct {
streamID uint32
srcIP [4]byte
dstIP [4]byte
srcPort uint16
dstPort uint16
@@ -203,8 +201,8 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
id := m.sfNode.Generate()
ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:])
ipSrc := l3.SrcIPAddr()
ipDst := l3.DstIPAddr()
if len(srcMAC) == 0 && m.macResolver != nil {
srcMAC = m.macResolver.Resolve(ipSrc)
}
@@ -248,8 +246,6 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
flow := &tcpFlow{
streamID: streamID,
srcIP: l3.SrcIP,
dstIP: l3.DstIP,
srcPort: tcp.SrcPort,
dstPort: tcp.DstPort,
info: info,
+84 -44
View File
@@ -150,56 +150,67 @@ func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
return io.VerdictAccept, nil
}
ipVersion := data[0] >> 4
if ipVersion == 4 {
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil
}
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
return w.serializeModifiedUDP(data, l3, udp, transport, modPayload)
}
return v, nil
default:
return io.VerdictAccept, nil
}
if v, b, ok := w.handleIPPacket(wp, data); ok {
return v, b
}
// Ethernet frame path (for custom PacketIO)
if ipVersion == 6 {
// TODO: IPv6 support with raw parsing
return io.VerdictAccept, nil
// Ethernet frame fallback path (for custom PacketIO implementations).
if l3Payload, ok := extractL3PayloadFromEthernet(data); ok {
if v, b, ok := w.handleIPPacket(wp, l3Payload); ok {
return v, b
}
}
return io.VerdictAccept, nil
}
func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) {
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil, false
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil, true
}
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil, true
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil, true
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
mv, mb := w.serializeModifiedUDP(data, l3, modPayload)
return mv, mb, true
}
return v, nil, true
default:
return io.VerdictAccept, nil, true
}
}
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:])
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
ipSrc := l3.SrcIPAddr()
ipDst := l3.DstIPAddr()
endpointType := layers.EndpointIPv4
flowSrc := ipSrc.To4()
flowDst := ipDst.To4()
if l3.Version == 6 {
endpointType = layers.EndpointIPv6
flowSrc = ipSrc.To16()
flowDst = ipDst.To16()
}
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc)
@@ -219,8 +230,12 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
return io.Verdict(uc.Verdict), uc.Packet
}
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) {
ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) {
layerType := layers.LayerTypeIPv4
if l3.Version == 6 {
layerType = layers.LayerTypeIPv6
}
ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
netLayer := ipPkt.NetworkLayer()
trLayer := ipPkt.TransportLayer()
if netLayer == nil || trLayer == nil {
@@ -240,3 +255,28 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, t
}
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
}
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
if len(data) < 14 {
return nil, false
}
offset := 12
etherType := uint16(data[offset])<<8 | uint16(data[offset+1])
offset += 2
for etherType == 0x8100 || etherType == 0x88A8 {
if len(data) < offset+4 {
return nil, false
}
etherType = uint16(data[offset+2])<<8 | uint16(data[offset+3])
offset += 4
}
if etherType != 0x0800 && etherType != 0x86DD {
return nil, false
}
if len(data) <= offset {
return nil, false
}
return data[offset:], true
}
+114
View File
@@ -0,0 +1,114 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
func TestWorkerHandleIPv6TCP(t *testing.T) {
w, err := newWorker(workerConfig{
ID: 0,
Logger: noopTestLogger{},
Ruleset: fixedRuleset{action: ruleset.ActionBlock},
ResultChan: make(chan workerResult, 1),
})
if err != nil {
t.Fatalf("new worker: %v", err)
}
src := net.ParseIP("2001:db8::11").To16()
dst := net.ParseIP("2001:db8::22").To16()
data := serializeIPv6TCP(t, src, dst, 42310, 443, 1000)
v, _ := w.handle(&workerPacket{
StreamID: 11,
Data: data,
})
if v != io.VerdictDropStream {
t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream)
}
}
func TestWorkerHandleIPv6UDP(t *testing.T) {
w, err := newWorker(workerConfig{
ID: 0,
Logger: noopTestLogger{},
Ruleset: fixedRuleset{action: ruleset.ActionBlock},
ResultChan: make(chan workerResult, 1),
})
if err != nil {
t.Fatalf("new worker: %v", err)
}
src := net.ParseIP("2001:db8::33").To16()
dst := net.ParseIP("2001:db8::44").To16()
data := serializeIPv6UDP(t, src, dst, 50000, 53, []byte("dns"))
v, _ := w.handle(&workerPacket{
StreamID: 12,
Data: data,
})
if v != io.VerdictDropStream {
t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream)
}
}
func serializeIPv6TCP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, seq uint32) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolTCP,
SrcIP: src,
DstIP: dst,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Seq: seq,
SYN: true,
}
if err := tcp.SetNetworkLayerForChecksum(ip6); err != nil {
t.Fatalf("set tcp checksum network layer: %v", err)
}
buf := gopacket.NewSerializeBuffer()
if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, ip6, tcp); err != nil {
t.Fatalf("serialize ipv6 tcp: %v", err)
}
return append([]byte(nil), buf.Bytes()...)
}
func serializeIPv6UDP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: src,
DstIP: dst,
}
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip6); err != nil {
t.Fatalf("set udp checksum network layer: %v", err)
}
buf := gopacket.NewSerializeBuffer()
if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, ip6, udp, gopacket.Payload(payload)); err != nil {
t.Fatalf("serialize ipv6 udp: %v", err)
}
return append([]byte(nil), buf.Bytes()...)
}