diff --git a/engine/engine.go b/engine/engine.go index 6b29271..483476c 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,6 +2,7 @@ package engine import ( "context" + "encoding/binary" "runtime" "git.difuse.io/Difuse/Mellaris/io" @@ -24,6 +25,7 @@ func NewEngine(config Config) (Engine, error) { if workerCount <= 0 { workerCount = runtime.NumCPU() } + macResolver := newSourceMACResolver() var err error workers := make([]*worker, workerCount) for i := range workers { @@ -32,6 +34,7 @@ func NewEngine(config Config) (Engine, error) { ChanSize: config.WorkerQueueSize, Logger: config.Logger, Ruleset: config.Ruleset, + MACResolver: macResolver, TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal, TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn, UDPMaxStreams: config.WorkerUDPMaxStreams, @@ -90,13 +93,8 @@ func (e *engine) Run(ctx context.Context) error { // dispatch dispatches a packet to a worker. func (e *engine) dispatch(p io.Packet) bool { data := p.Data() - ipVersion := data[0] >> 4 - var layerType gopacket.LayerType - if ipVersion == 4 { - layerType = layers.LayerTypeIPv4 - } else if ipVersion == 6 { - layerType = layers.LayerTypeIPv6 - } else { + layerType, srcMAC, dstMAC, ok := classifyPacket(data) + if !ok { // Unsupported network layer _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil) return true @@ -107,9 +105,41 @@ func (e *engine) dispatch(p io.Packet) bool { e.workers[index].Feed(&workerPacket{ StreamID: p.StreamID(), Packet: packet, + SrcMAC: srcMAC, + DstMAC: dstMAC, SetVerdict: func(v io.Verdict, b []byte) error { return e.io.SetVerdict(p, v, b) }, }) return true } + +// classifyPacket detects packet framing and returns a gopacket decode layer +// plus best-effort source/destination MAC addresses when available. +func classifyPacket(data []byte) (gopacket.LayerType, []byte, []byte, bool) { + if len(data) == 0 { + return 0, nil, nil, false + } + + // Fast path for IP packets (NFQUEUE payloads are typically IP-only). + ipVersion := data[0] >> 4 + if ipVersion == 4 { + return layers.LayerTypeIPv4, nil, nil, true + } + if ipVersion == 6 { + return layers.LayerTypeIPv6, nil, nil, true + } + + // Ethernet frame path (for custom PacketIO implementations). + if len(data) >= 14 { + etherType := binary.BigEndian.Uint16(data[12:14]) + if etherType == uint16(layers.EthernetTypeIPv4) || etherType == uint16(layers.EthernetTypeIPv6) { + return layers.LayerTypeEthernet, + append([]byte(nil), data[6:12]...), + append([]byte(nil), data[:6]...), + true + } + } + + return 0, nil, nil, false +} diff --git a/engine/mac_resolver.go b/engine/mac_resolver.go new file mode 100644 index 0000000..d8c1a0e --- /dev/null +++ b/engine/mac_resolver.go @@ -0,0 +1,145 @@ +package engine + +import ( + "bufio" + "net" + "os" + "strings" + "sync" + "time" +) + +const ( + ifaceCacheTTL = 30 * time.Second + arpCacheTTL = 10 * time.Second +) + +type sourceMACResolver struct { + mu sync.RWMutex + + lastIfaceRefresh time.Time + ifaceByIP map[string]net.HardwareAddr + + lastARPRefresh time.Time + arpByIP map[string]net.HardwareAddr +} + +func newSourceMACResolver() *sourceMACResolver { + return &sourceMACResolver{ + ifaceByIP: make(map[string]net.HardwareAddr), + arpByIP: make(map[string]net.HardwareAddr), + } +} + +func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr { + if ip == nil { + return nil + } + ipKey := ip.String() + if ipKey == "" { + return nil + } + + now := time.Now() + r.mu.RLock() + ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL + arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL + if mac := r.ifaceByIP[ipKey]; len(mac) != 0 { + out := append(net.HardwareAddr(nil), mac...) + r.mu.RUnlock() + return out + } + if mac := r.arpByIP[ipKey]; len(mac) != 0 && !arpRefreshDue { + out := append(net.HardwareAddr(nil), mac...) + r.mu.RUnlock() + return out + } + r.mu.RUnlock() + + if ifaceRefreshDue { + r.refreshIfaceCache(now) + } + if arpRefreshDue { + r.refreshARPCache(now) + } + + r.mu.RLock() + defer r.mu.RUnlock() + if mac := r.ifaceByIP[ipKey]; len(mac) != 0 { + return append(net.HardwareAddr(nil), mac...) + } + if mac := r.arpByIP[ipKey]; len(mac) != 0 { + return append(net.HardwareAddr(nil), mac...) + } + return nil +} + +func (r *sourceMACResolver) refreshIfaceCache(now time.Time) { + interfaces, err := net.Interfaces() + if err != nil { + return + } + + m := make(map[string]net.HardwareAddr) + for _, iface := range interfaces { + if len(iface.HardwareAddr) == 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok || ipNet.IP == nil { + continue + } + m[ipNet.IP.String()] = append(net.HardwareAddr(nil), iface.HardwareAddr...) + } + } + + r.mu.Lock() + r.ifaceByIP = m + r.lastIfaceRefresh = now + r.mu.Unlock() +} + +func (r *sourceMACResolver) refreshARPCache(now time.Time) { + f, err := os.Open("/proc/net/arp") + if err != nil { + return + } + defer f.Close() + + m := make(map[string]net.HardwareAddr) + scanner := bufio.NewScanner(f) + lineNo := 0 + for scanner.Scan() { + lineNo++ + if lineNo == 1 { + continue // header + } + fields := strings.Fields(scanner.Text()) + if len(fields) < 4 { + continue + } + ipStr := fields[0] + hwAddr := fields[3] + if hwAddr == "00:00:00:00:00:00" { + continue + } + mac, err := net.ParseMAC(hwAddr) + if err != nil { + continue + } + m[ipStr] = append(net.HardwareAddr(nil), mac...) + } + if err := scanner.Err(); err != nil { + return + } + + r.mu.Lock() + r.arpByIP = m + r.lastARPRefresh = now + r.mu.Unlock() +} diff --git a/engine/tcp.go b/engine/tcp.go index c019b5b..52b8f3d 100644 --- a/engine/tcp.go +++ b/engine/tcp.go @@ -27,7 +27,8 @@ const ( type tcpContext struct { *gopacket.PacketMetadata - Verdict tcpVerdict + Verdict tcpVerdict + SrcMAC, DstMAC net.HardwareAddr } func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo { @@ -46,9 +47,12 @@ type tcpStreamFactory struct { func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { id := f.Node.Generate() ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) + ctx := ac.(*tcpContext) info := ruleset.StreamInfo{ ID: id.Int64(), Protocol: ruleset.ProtocolTCP, + SrcMAC: append(net.HardwareAddr(nil), ctx.SrcMAC...), + DstMAC: append(net.HardwareAddr(nil), ctx.DstMAC...), SrcIP: ipSrc, DstIP: ipDst, SrcPort: uint16(tcp.SrcPort), diff --git a/engine/udp.go b/engine/udp.go index c3b7bf2..6bea55e 100644 --- a/engine/udp.go +++ b/engine/udp.go @@ -31,8 +31,9 @@ const ( var errInvalidModifier = errors.New("invalid modifier") type udpContext struct { - Verdict udpVerdict - Packet []byte + Verdict udpVerdict + Packet []byte + SrcMAC, DstMAC net.HardwareAddr } type udpStreamFactory struct { @@ -50,6 +51,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u info := ruleset.StreamInfo{ ID: id.Int64(), Protocol: ruleset.ProtocolUDP, + SrcMAC: append(net.HardwareAddr(nil), uc.SrcMAC...), + DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...), SrcIP: ipSrc, DstIP: ipDst, SrcPort: uint16(udp.SrcPort), diff --git a/engine/worker.go b/engine/worker.go index 5086173..e146637 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -2,6 +2,7 @@ package engine import ( "context" + "net" "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" @@ -22,13 +23,16 @@ const ( type workerPacket struct { StreamID uint32 Packet gopacket.Packet + SrcMAC net.HardwareAddr + DstMAC net.HardwareAddr SetVerdict func(io.Verdict, []byte) error } type worker struct { - id int - packetChan chan *workerPacket - logger Logger + id int + packetChan chan *workerPacket + logger Logger + macResolver *sourceMACResolver tcpStreamFactory *tcpStreamFactory tcpStreamPool *reassembly.StreamPool @@ -45,6 +49,7 @@ type workerConfig struct { ChanSize int Logger Logger Ruleset ruleset.Ruleset + MACResolver *sourceMACResolver TCPMaxBufferedPagesTotal int TCPMaxBufferedPagesPerConn int UDPMaxStreams int @@ -95,6 +100,7 @@ func newWorker(config workerConfig) (*worker, error) { id: config.ID, packetChan: make(chan *workerPacket, config.ChanSize), logger: config.Logger, + macResolver: config.MACResolver, tcpStreamFactory: tcpSF, tcpStreamPool: tcpStreamPool, tcpAssembler: tcpAssembler, @@ -120,7 +126,7 @@ func (w *worker) Run(ctx context.Context) { // Closed return } - v, b := w.handle(wPkt.StreamID, wPkt.Packet) + v, b := w.handle(wPkt.StreamID, wPkt.Packet, wPkt.SrcMAC, wPkt.DstMAC) _ = wPkt.SetVerdict(v, b) } } @@ -133,18 +139,21 @@ func (w *worker) UpdateRuleset(r ruleset.Ruleset) error { return w.udpStreamFactory.UpdateRuleset(r) } -func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) { +func (w *worker) handle(streamID uint32, p gopacket.Packet, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { netLayer, trLayer := p.NetworkLayer(), p.TransportLayer() if netLayer == nil || trLayer == nil { // Invalid packet return io.VerdictAccept, nil } ipFlow := netLayer.NetworkFlow() + if len(srcMAC) == 0 && w.macResolver != nil { + srcMAC = w.macResolver.Resolve(net.IP(ipFlow.Src().Raw())) + } switch tr := trLayer.(type) { case *layers.TCP: - return w.handleTCP(ipFlow, p.Metadata(), tr), nil + return w.handleTCP(ipFlow, srcMAC, dstMAC, p.Metadata(), tr), nil case *layers.UDP: - v, modPayload := w.handleUDP(streamID, ipFlow, tr) + v, modPayload := w.handleUDP(streamID, ipFlow, srcMAC, dstMAC, tr) if v == io.VerdictAcceptModify && modPayload != nil { tr.Payload = modPayload _ = tr.SetNetworkLayerForChecksum(netLayer) @@ -167,18 +176,22 @@ func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) } } -func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict { +func (w *worker) handleTCP(ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict { ctx := &tcpContext{ PacketMetadata: pMeta, Verdict: tcpVerdictAccept, + SrcMAC: srcMAC, + DstMAC: dstMAC, } w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx) return io.Verdict(ctx.Verdict) } -func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) { +func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, srcMAC, dstMAC net.HardwareAddr, udp *layers.UDP) (io.Verdict, []byte) { ctx := &udpContext{ Verdict: udpVerdictAccept, + SrcMAC: srcMAC, + DstMAC: dstMAC, } w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx) return io.Verdict(ctx.Verdict), ctx.Packet diff --git a/ruleset/expr.go b/ruleset/expr.go index 45a6293..11682f0 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -191,6 +191,10 @@ func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { m := map[string]interface{}{ "id": info.ID, "proto": info.Protocol.String(), + "mac": map[string]string{ + "src": info.SrcMAC.String(), + "dst": info.DstMAC.String(), + }, "ip": map[string]string{ "src": info.SrcIP.String(), "dst": info.DstIP.String(), @@ -211,7 +215,7 @@ func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { func isBuiltInAnalyzer(name string) bool { switch name { - case "id", "proto", "ip", "port": + case "id", "proto", "mac", "ip", "port": return true default: return false diff --git a/ruleset/interface.go b/ruleset/interface.go index e190544..f7821eb 100644 --- a/ruleset/interface.go +++ b/ruleset/interface.go @@ -67,6 +67,7 @@ const ( type StreamInfo struct { ID int64 Protocol Protocol + SrcMAC, DstMAC net.HardwareAddr SrcIP, DstIP net.IP SrcPort, DstPort uint16 Props analyzer.CombinedPropMap