package engine

import (
	"context"
	"net"

	"git.difuse.io/Difuse/Mellaris/io"
	"git.difuse.io/Difuse/Mellaris/ruleset"
	"github.com/bwmarrin/snowflake"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
)

// workerPacket is one unit of work handed to a worker: the captured packet,
// its raw bytes, and the link-layer addresses (if known) of the sender and
// receiver. Packet, StreamID and Gen are copied unchanged into the
// corresponding workerResult.
type workerPacket struct {
	Packet   io.Packet
	StreamID uint32
	Data     []byte
	SrcMAC   net.HardwareAddr
	DstMAC   net.HardwareAddr
	Gen      int64
}

// workerResult is the verdict a worker produced for one workerPacket.
// ModifiedPacket is non-nil only when Verdict is io.VerdictAcceptModify and
// the packet was successfully re-serialized; it is owned by the receiver.
type workerResult struct {
	Packet         io.Packet
	StreamID       uint32
	Verdict        io.Verdict
	ModifiedPacket []byte
	Gen            int64
}

// worker classifies packets received on packetChan and publishes one
// workerResult per packet on resultChan. TCP packets are delegated to
// tcpFlowMgr and UDP packets to udpSM. modSerializeBuffer is scratch space
// reused for re-serializing modified UDP packets; it is only touched from
// the goroutine running Run.
type worker struct {
	id                 int
	packetChan         chan *workerPacket
	resultChan         chan workerResult
	logger             Logger
	macResolver        *sourceMACResolver
	tcpFlowMgr         *tcpFlowManager
	udpSM              *udpStreamManager
	modSerializeBuffer gopacket.SerializeBuffer
}

// workerConfig holds the parameters used by newWorker.
type workerConfig struct {
	ID                         int
	ChanSize                   int
	Logger                     Logger
	Ruleset                    ruleset.Ruleset
	MACResolver                *sourceMACResolver
	TCPMaxBufferedPagesTotal   int // unused, kept for config compat
	TCPMaxBufferedPagesPerConn int // unused, kept for config compat
	UDPMaxStreams              int
	AnalyzerSelectionMode      AnalyzerSelectionMode
	ResultChan                 chan workerResult
	Stats                      *statsCounters
}

// fillDefaults replaces non-positive ChanSize and UDPMaxStreams with their
// defaults (64 and 4096 respectively).
func (c *workerConfig) fillDefaults() {
	if c.ChanSize <= 0 {
		c.ChanSize = 64
	}
	if c.UDPMaxStreams <= 0 {
		c.UDPMaxStreams = 4096
	}
}

// newWorker builds a worker from config (after applying defaults). The
// worker ID doubles as the snowflake node ID, so an ID that snowflake
// rejects is returned as an error. If config.Ruleset is non-nil it is
// installed on both the TCP flow manager and the UDP stream factory.
func newWorker(config workerConfig) (*worker, error) {
	config.fillDefaults()
	sfNode, err := snowflake.NewNode(int64(config.ID))
	if err != nil {
		return nil, err
	}
	selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)

	tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
	if config.Ruleset != nil {
		tcpMgr.updateRuleset(config.Ruleset, 0)
	}

	udpSF := &udpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     sfNode,
		Ruleset:  config.Ruleset,
		Selector: selector,
		Stats:    config.Stats,
	}
	udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
	if err != nil {
		return nil, err
	}

	return &worker{
		id:                 config.ID,
		packetChan:         make(chan *workerPacket, config.ChanSize),
		resultChan:         config.ResultChan,
		logger:             config.Logger,
		macResolver:        config.MACResolver,
		tcpFlowMgr:         tcpMgr,
		udpSM:              udpSM,
		modSerializeBuffer: gopacket.NewSerializeBuffer(),
	}, nil
}

// Feed enqueues p without blocking. It reports false if the worker's queue
// is full and the packet was not accepted.
func (w *worker) Feed(p *workerPacket) bool {
	select {
	case w.packetChan <- p:
		return true
	default:
		return false
	}
}

// FeedBlocking enqueues p, waiting for queue space if necessary.
func (w *worker) FeedBlocking(p *workerPacket) {
	w.packetChan <- p
}

// Run processes queued packets until ctx is cancelled or a nil packet is
// received, emitting one workerResult per processed packet on resultChan.
func (w *worker) Run(ctx context.Context) {
	w.logger.WorkerStart(w.id)
	defer w.logger.WorkerStop(w.id)
	for {
		select {
		case <-ctx.Done():
			return
		case wp := <-w.packetChan:
			if wp == nil {
				// A nil packet is the explicit stop signal.
				return
			}
			v, b := w.handle(wp)
			res := workerResult{
				Packet:         wp.Packet,
				StreamID:       wp.StreamID,
				Verdict:        v,
				ModifiedPacket: b,
				Gen:            wp.Gen,
			}
			// Respect cancellation while publishing: if the reader of
			// resultChan has stopped, a plain send would block this
			// goroutine forever. During shutdown the result may be dropped.
			select {
			case w.resultChan <- res:
			case <-ctx.Done():
				return
			}
		}
	}
}

// UpdateRuleset installs r on the TCP flow manager and the UDP stream
// factory, returning any error from the UDP side.
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
	w.tcpFlowMgr.updateRuleset(r, 0)
	return w.udpSM.factory.UpdateRuleset(r)
}

// handle returns the verdict (and, for modified UDP packets, the
// re-serialized packet) for wp. Data is first parsed as a raw IP packet; if
// that fails it is retried as an Ethernet frame. Anything unparseable is
// accepted unmodified.
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
	data := wp.Data
	if len(data) == 0 {
		return io.VerdictAccept, nil
	}

	if v, b, ok := w.handleIPPacket(wp, data); ok {
		return v, b
	}

	// Ethernet frame fallback path (for custom PacketIO implementations).
	// NOTE(review): on this path a modified UDP packet is re-serialized from
	// the L3 payload only, so the returned bytes carry no Ethernet header —
	// confirm that such PacketIO implementations expect an IP packet back.
	if l3Payload, ok := extractL3PayloadFromEthernet(data); ok {
		if v, b, ok := w.handleIPPacket(wp, l3Payload); ok {
			return v, b
		}
	}

	return io.VerdictAccept, nil
}

// handleIPPacket dispatches an IP packet to the TCP or UDP path by its
// protocol number. The bool result reports whether data parsed as an IP
// packet at all; malformed TCP/UDP headers and other protocols are accepted
// unmodified (with ok == true).
func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) {
	l3, transport, ok := ParseL3(data)
	if !ok {
		return io.VerdictAccept, nil, false
	}

	switch l3.Protocol {
	case 6: // TCP
		tcp, payload, ok := ParseTCP(transport)
		if !ok {
			return io.VerdictAccept, nil, true
		}
		verdict := w.tcpFlowMgr.handle(
			wp.StreamID, l3, tcp, payload, wp.SrcMAC, wp.DstMAC,
		)
		return verdict, nil, true

	case 17: // UDP
		udp, payload, ok := ParseUDP(transport)
		if !ok {
			return io.VerdictAccept, nil, true
		}
		v, modPayload := w.handleUDP(
			wp.StreamID, l3, udp, payload, wp.SrcMAC, wp.DstMAC,
		)
		if v == io.VerdictAcceptModify && modPayload != nil {
			mv, mb := w.serializeModifiedUDP(data, l3, modPayload)
			return mv, mb, true
		}
		return v, nil, true

	default:
		return io.VerdictAccept, nil, true
	}
}

// handleUDP runs the UDP stream manager on one datagram and returns its
// verdict together with the replacement payload (if the analyzers asked for
// one). When srcMAC is empty it is looked up via the MAC resolver.
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
	ipSrc := l3.SrcIPAddr()
	ipDst := l3.DstIPAddr()

	endpointType := layers.EndpointIPv4
	flowSrc := ipSrc.To4()
	flowDst := ipDst.To4()
	if l3.Version == 6 {
		endpointType = layers.EndpointIPv6
		flowSrc = ipSrc.To16()
		flowDst = ipDst.To16()
	}
	ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)

	if len(srcMAC) == 0 && w.macResolver != nil {
		srcMAC = w.macResolver.Resolve(ipSrc)
	}

	uc := &udpContext{
		Verdict: udpVerdictAccept,
		SrcMAC:  srcMAC,
		DstMAC:  dstMAC,
	}

	// Temporarily set payload on a UDP layer so existing UDP handling works.
	w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
		BaseLayer: layers.BaseLayer{Payload: payload},
		SrcPort:   layers.UDPPort(udp.SrcPort),
		DstPort:   layers.UDPPort(udp.DstPort),
	}, uc)

	return io.Verdict(uc.Verdict), uc.Packet
}

// serializeModifiedUDP rebuilds the IP packet in fullData with its UDP
// payload replaced by modPayload, fixing lengths and checksums. On any
// decode or serialization failure it falls back to accepting the original
// packet unmodified. The returned slice is a fresh copy owned by the caller.
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) {
	layerType := layers.LayerTypeIPv4
	if l3.Version == 6 {
		layerType = layers.LayerTypeIPv6
	}

	ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
	netLayer := ipPkt.NetworkLayer()
	trLayer := ipPkt.TransportLayer()
	if netLayer == nil || trLayer == nil {
		return io.VerdictAccept, nil
	}

	udpLayer, ok := trLayer.(*layers.UDP)
	if !ok {
		return io.VerdictAccept, nil
	}

	// NOTE(review): this relies on Lazy decoding — layers after UDP have not
	// been decoded yet, so SerializePacket (via Packet.Layers) decodes them
	// from the replaced Payload. Re-check if the DecodeOptions change.
	udpLayer.Payload = modPayload
	// Best-effort: without a checksum network layer the UDP checksum is
	// simply not recomputed.
	_ = udpLayer.SetNetworkLayerForChecksum(netLayer)

	// Clear only resets the reusable scratch buffer; its error is ignored.
	_ = w.modSerializeBuffer.Clear()
	err := gopacket.SerializePacket(w.modSerializeBuffer, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
	if err != nil {
		return io.VerdictAccept, nil
	}

	// Bytes() aliases the buffer's internal storage, which the next
	// Clear/serialize overwrites. The result travels to another goroutine
	// through resultChan while this worker moves on to the next packet, so
	// hand out a private copy instead of the shared storage.
	out := append([]byte(nil), w.modSerializeBuffer.Bytes()...)
	return io.VerdictAcceptModify, out
}

// extractL3PayloadFromEthernet returns the bytes following the Ethernet
// header (skipping any 802.1Q / 802.1ad VLAN tags) when the frame carries
// IPv4 (0x0800) or IPv6 (0x86DD). It reports false for short frames, other
// EtherTypes, or frames with no payload after the header.
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
	// Destination MAC (6) + source MAC (6) + EtherType (2).
	if len(data) < 14 {
		return nil, false
	}

	offset := 12
	etherType := uint16(data[offset])<<8 | uint16(data[offset+1])
	offset += 2

	// Each VLAN tag is 4 bytes: 2-byte TCI followed by the inner EtherType.
	for etherType == 0x8100 || etherType == 0x88A8 {
		if len(data) < offset+4 {
			return nil, false
		}
		etherType = uint16(data[offset+2])<<8 | uint16(data[offset+3])
		offset += 4
	}

	if etherType != 0x0800 && etherType != 0x86DD {
		return nil, false
	}
	if len(data) <= offset {
		return nil, false
	}
	return data[offset:], true
}