Files
Mellaris/engine/worker.go
T
2026-05-14 09:41:07 +05:30

283 lines
6.8 KiB
Go

package engine
import (
"context"
"net"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// workerPacket is one unit of work queued to a worker via Feed or FeedBlocking
// and consumed by Run.
type workerPacket struct {
	// Packet is the originating io.Packet; Run copies it into the workerResult
	// unchanged.
	Packet io.Packet
	// StreamID is passed to the TCP/UDP handlers and echoed in the result.
	StreamID uint32
	// Data is the raw bytes to inspect. handle first tries to parse it as an
	// IP packet and, failing that, as an Ethernet frame carrying IPv4/IPv6.
	// An empty Data is accepted without inspection.
	Data []byte
	// SrcMAC and DstMAC are forwarded to the TCP/UDP handlers. On the UDP path
	// an empty SrcMAC is filled in from the worker's MAC resolver, if set.
	SrcMAC net.HardwareAddr
	DstMAC net.HardwareAddr
	// Gen is not interpreted by the worker; Run echoes it in workerResult.
	// NOTE(review): presumably a generation tag the consumer uses to discard
	// stale results — confirm with the result consumer.
	Gen int64
}
// workerResult is what Run publishes on the worker's result channel for each
// processed workerPacket.
type workerResult struct {
	// Packet, StreamID and Gen are copied from the originating workerPacket.
	Packet   io.Packet
	StreamID uint32
	// Verdict is the decision returned by worker.handle.
	Verdict io.Verdict
	// ModifiedPacket is non-nil only when a UDP handler asked for the payload
	// to be rewritten and re-serialization succeeded. It holds a serialized IP
	// packet (no Ethernet header, even if the input was an Ethernet frame).
	ModifiedPacket []byte
	Gen            int64
}
// worker holds the per-worker TCP/UDP processing state. Packets are queued
// with Feed/FeedBlocking, processed by Run, and the verdicts are published on
// resultChan.
type worker struct {
	id int
	// packetChan is the input queue (capacity workerConfig.ChanSize). A nil
	// entry tells Run to return.
	packetChan chan *workerPacket
	// resultChan comes from workerConfig.ResultChan.
	// NOTE(review): possibly shared by several workers — confirm with the engine.
	resultChan chan workerResult
	logger     Logger
	// macResolver, when non-nil, supplies the source MAC for UDP packets that
	// arrive without one.
	macResolver *sourceMACResolver
	tcpFlowMgr  *tcpFlowManager
	udpSM       *udpStreamManager
	// modSerializeBuffer is scratch space reused by every serializeModifiedUDP
	// call. NOTE(review): it is used without locking, which assumes only one
	// goroutine runs handle for this worker — confirm.
	modSerializeBuffer gopacket.SerializeBuffer
}
// workerConfig carries the parameters newWorker uses to build a worker.
// Non-positive ChanSize and UDPMaxStreams are replaced by fillDefaults.
type workerConfig struct {
	// ID identifies the worker and is also used as the snowflake node number;
	// newWorker returns snowflake.NewNode's error if the ID is rejected.
	// NOTE(review): snowflake limits the node range — confirm IDs stay within it.
	ID int
	// ChanSize is the capacity of the worker's input queue (default 64).
	ChanSize int
	Logger   Logger
	// Ruleset is the initial ruleset. It is applied to the TCP flow manager
	// only when non-nil; the UDP stream factory receives it as-is.
	Ruleset ruleset.Ruleset
	// MACResolver is handed to the TCP flow manager and is also used on the
	// UDP path to fill in a missing source MAC.
	MACResolver                *sourceMACResolver
	TCPMaxBufferedPagesTotal   int // unused, kept for config compat
	TCPMaxBufferedPagesPerConn int // unused, kept for config compat
	// UDPMaxStreams is passed to newUDPStreamManager (default 4096).
	UDPMaxStreams         int
	AnalyzerSelectionMode AnalyzerSelectionMode
	// ResultChan is where Run publishes each packet's workerResult.
	ResultChan chan workerResult
	// Stats is shared by the analyzer selector and the UDP stream machinery.
	Stats *statsCounters
}
// fillDefaults replaces non-positive ChanSize and UDPMaxStreams values with
// the built-in defaults. All other fields are left as provided.
func (c *workerConfig) fillDefaults() {
	const (
		defaultChanSize      = 64
		defaultUDPMaxStreams = 4096
	)
	if c.ChanSize < 1 {
		c.ChanSize = defaultChanSize
	}
	if c.UDPMaxStreams < 1 {
		c.UDPMaxStreams = defaultUDPMaxStreams
	}
}
// newWorker builds a worker from config after applying defaults.
//
// It creates a snowflake node keyed by config.ID and one analyzer selector
// shared by the TCP and UDP paths. It then creates a TCP flow manager (seeded
// with config.Ruleset when that is non-nil) and a UDP stream manager limited by
// config.UDPMaxStreams. Errors from snowflake.NewNode or newUDPStreamManager
// are returned unwrapped.
func newWorker(config workerConfig) (*worker, error) {
	config.fillDefaults()

	node, err := snowflake.NewNode(int64(config.ID))
	if err != nil {
		return nil, err
	}
	sel := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)

	tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, node, sel)
	if rs := config.Ruleset; rs != nil {
		tcpMgr.updateRuleset(rs, 0)
	}

	// Unlike the TCP manager, the UDP factory is given the ruleset even if nil.
	udpMgr, err := newUDPStreamManager(&udpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     node,
		Ruleset:  config.Ruleset,
		Selector: sel,
		Stats:    config.Stats,
	}, config.UDPMaxStreams, config.Stats)
	if err != nil {
		return nil, err
	}

	w := &worker{
		id:                 config.ID,
		packetChan:         make(chan *workerPacket, config.ChanSize),
		resultChan:         config.ResultChan,
		logger:             config.Logger,
		macResolver:        config.MACResolver,
		tcpFlowMgr:         tcpMgr,
		udpSM:              udpMgr,
		modSerializeBuffer: gopacket.NewSerializeBuffer(),
	}
	return w, nil
}
// Feed enqueues p without blocking. It reports false, and drops nothing
// itself, when the input queue is currently full; the caller decides what to
// do with the rejected packet.
func (w *worker) Feed(p *workerPacket) bool {
	select {
	case w.packetChan <- p:
	default:
		return false
	}
	return true
}
// FeedBlocking enqueues p, waiting for space in the input queue if necessary.
// There is no cancellation: it blocks until Run (or another reader) drains the
// queue. Feeding nil asks Run to return.
func (w *worker) FeedBlocking(p *workerPacket) {
	queue := w.packetChan
	queue <- p
}
// Run is the worker's processing loop. It receives packets from the input
// queue, handles each one, and publishes a workerResult on resultChan.
//
// Run returns when ctx is cancelled or when it receives a nil packet. A nil
// packet is the explicit stop sentinel, and it is also what a receive from a
// closed queue yields. Publishing a result also honours ctx. A stalled or
// departed consumer therefore cannot pin this goroutine forever during
// shutdown; on cancellation the in-flight result is dropped.
func (w *worker) Run(ctx context.Context) {
	w.logger.WorkerStart(w.id)
	defer w.logger.WorkerStop(w.id)
	for {
		select {
		case <-ctx.Done():
			return
		case wp := <-w.packetChan:
			if wp == nil {
				return
			}
			v, b := w.handle(wp)
			res := workerResult{
				Packet:         wp.Packet,
				StreamID:       wp.StreamID,
				Verdict:        v,
				ModifiedPacket: b,
				Gen:            wp.Gen,
			}
			// A bare send here would block forever once the consumer stops
			// reading, leaking the goroutine even after ctx is cancelled.
			select {
			case w.resultChan <- res:
			case <-ctx.Done():
				return
			}
		}
	}
}
// UpdateRuleset propagates r to the TCP flow manager and then to the UDP stream
// factory. The returned error is the UDP factory's. The TCP side has already
// been updated by the time that error is seen.
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
	w.tcpFlowMgr.updateRuleset(r, 0)
	factory := w.udpSM.factory
	return factory.UpdateRuleset(r)
}
// handle computes the verdict for one queued packet. Empty data is accepted
// outright. Otherwise the bytes are tried as a raw IP packet first; if they do
// not parse as IP, they are tried as an Ethernet frame carrying IPv4/IPv6.
// Anything that matches neither form is accepted unmodified. The second return
// value is the re-serialized packet for modify verdicts, else nil.
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
	raw := wp.Data
	if len(raw) == 0 {
		return io.VerdictAccept, nil
	}
	if verdict, out, handled := w.handleIPPacket(wp, raw); handled {
		return verdict, out
	}
	// Ethernet frame fallback path (for custom PacketIO implementations).
	inner, isEthernet := extractL3PayloadFromEthernet(raw)
	if !isEthernet {
		return io.VerdictAccept, nil
	}
	verdict, out, handled := w.handleIPPacket(wp, inner)
	if !handled {
		return io.VerdictAccept, nil
	}
	return verdict, out
}
// handleIPPacket processes data as an IPv4/IPv6 packet.
//
// The boolean result is false only when data does not parse as an L3 packet,
// which lets the caller try another framing. Once L3 parsing succeeds, the
// packet counts as handled:
//   - TCP goes to the flow manager.
//   - UDP goes to handleUDP, and a modify verdict with a payload is
//     re-serialized.
//   - Malformed transport headers and other protocols are accepted.
func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) {
	const (
		ipProtoTCP = 6
		ipProtoUDP = 17
	)
	l3, transport, parsed := ParseL3(data)
	if !parsed {
		return io.VerdictAccept, nil, false
	}

	switch l3.Protocol {
	case ipProtoTCP:
		tcp, tcpPayload, tcpOK := ParseTCP(transport)
		if !tcpOK {
			return io.VerdictAccept, nil, true
		}
		return w.tcpFlowMgr.handle(wp.StreamID, l3, tcp, tcpPayload, wp.SrcMAC, wp.DstMAC), nil, true

	case ipProtoUDP:
		udp, udpPayload, udpOK := ParseUDP(transport)
		if !udpOK {
			return io.VerdictAccept, nil, true
		}
		verdict, replacement := w.handleUDP(wp.StreamID, l3, udp, udpPayload, wp.SrcMAC, wp.DstMAC)
		if verdict != io.VerdictAcceptModify || replacement == nil {
			return verdict, nil, true
		}
		finalVerdict, rebuilt := w.serializeModifiedUDP(data, l3, replacement)
		return finalVerdict, rebuilt, true
	}

	return io.VerdictAccept, nil, true
}
// handleUDP runs one UDP datagram through the worker's UDP stream manager.
//
// It builds a gopacket network flow from l3's addresses: 4-byte form for IPv4,
// 16-byte form for IPv6. When srcMAC is empty and a MAC resolver is configured,
// srcMAC is resolved from the source IP. The manager is then invoked with a
// udpContext that starts out as "accept". It returns the context's final
// verdict and its Packet, which the caller treats as the replacement payload
// for modify verdicts.
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
	src, dst := l3.SrcIPAddr(), l3.DstIPAddr()

	var ipFlow gopacket.Flow
	if l3.Version == 6 {
		ipFlow = gopacket.NewFlow(layers.EndpointIPv6, src.To16(), dst.To16())
	} else {
		ipFlow = gopacket.NewFlow(layers.EndpointIPv4, src.To4(), dst.To4())
	}

	if w.macResolver != nil && len(srcMAC) == 0 {
		srcMAC = w.macResolver.Resolve(src)
	}

	uctx := &udpContext{
		Verdict: udpVerdictAccept,
		SrcMAC:  srcMAC,
		DstMAC:  dstMAC,
	}
	// The stream manager works on *layers.UDP, so wrap the pre-parsed header
	// fields and payload in one.
	datagram := &layers.UDP{
		BaseLayer: layers.BaseLayer{Payload: payload},
		SrcPort:   layers.UDPPort(udp.SrcPort),
		DstPort:   layers.UDPPort(udp.DstPort),
	}
	w.udpSM.MatchWithContext(streamID, ipFlow, datagram, uctx)

	return io.Verdict(uctx.Verdict), uctx.Packet
}
// serializeModifiedUDP rebuilds the IP packet in fullData with its UDP payload
// replaced by modPayload, fixing lengths and recomputing checksums.
//
// fullData is the L3 (IPv4/IPv6) packet that was parsed into l3; only
// l3.Version is consulted here, to choose the gopacket decoder. On success it
// returns io.VerdictAcceptModify and the rebuilt packet. If the packet cannot
// be decoded as IP/UDP, its checksum pseudo-header cannot be set, or
// re-serialization fails, it falls back to io.VerdictAccept with a nil packet.
//
// The returned slice is a fresh copy owned by the caller. w.modSerializeBuffer
// is reused and overwritten by every call. Run publishes the result to another
// goroutine via resultChan before moving on to the next packet. Returning the
// buffer's storage directly would let a later modification corrupt a verdict
// that has not been consumed yet.
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) {
	layerType := layers.LayerTypeIPv4
	if l3.Version == 6 {
		layerType = layers.LayerTypeIPv6
	}
	ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
	netLayer := ipPkt.NetworkLayer()
	trLayer := ipPkt.TransportLayer()
	if netLayer == nil || trLayer == nil {
		return io.VerdictAccept, nil
	}
	udpLayer, ok := trLayer.(*layers.UDP)
	if !ok {
		return io.VerdictAccept, nil
	}
	// NOTE(review): with Lazy decoding, layers after UDP have not been decoded
	// yet at this point, so SerializePacket appears to decode them from the
	// replaced payload. Keep this ordering in mind before switching to eager
	// decoding — confirm.
	udpLayer.Payload = modPayload
	// Checksums are requested below; without a pseudo-header they cannot be
	// computed, so bail out explicitly instead of discarding the error.
	if err := udpLayer.SetNetworkLayerForChecksum(netLayer); err != nil {
		return io.VerdictAccept, nil
	}
	_ = w.modSerializeBuffer.Clear() // best-effort reset of the reused scratch buffer
	err := gopacket.SerializePacket(w.modSerializeBuffer,
		gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
	if err != nil {
		return io.VerdictAccept, nil
	}
	serialized := w.modSerializeBuffer.Bytes()
	out := make([]byte, len(serialized))
	copy(out, serialized)
	return io.VerdictAcceptModify, out
}
// extractL3PayloadFromEthernet strips an Ethernet II header from data and
// returns the bytes that follow it. Any stacked 802.1Q (0x8100) or 802.1ad
// (0x88A8) VLAN tags are stripped too. A payload is returned only when the
// innermost EtherType is IPv4 (0x0800) or IPv6 (0x86DD).
//
// It reports false in these cases:
//   - the frame is shorter than the 14-byte header;
//   - a VLAN tag is truncated;
//   - the EtherType is something else;
//   - nothing follows the headers.
//
// The returned slice aliases data.
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
	const (
		etherTypeOffset = 12
		vlanTagLen      = 4
		etherTypeIPv4   = 0x0800
		etherTypeIPv6   = 0x86DD
		etherTypeVLAN   = 0x8100
		etherTypeQinQ   = 0x88A8
	)
	be16 := func(i int) uint16 { return uint16(data[i])<<8 | uint16(data[i+1]) }

	if len(data) < etherTypeOffset+2 {
		return nil, false
	}
	pos := etherTypeOffset
	etherType := be16(pos)
	pos += 2

	for etherType == etherTypeVLAN || etherType == etherTypeQinQ {
		if pos+vlanTagLen > len(data) {
			return nil, false
		}
		// Each tag is a 2-byte TCI followed by the encapsulated EtherType.
		etherType = be16(pos + 2)
		pos += vlanTagLen
	}

	switch etherType {
	case etherTypeIPv4, etherTypeIPv6:
	default:
		return nil, false
	}
	if pos >= len(data) {
		return nil, false
	}
	return data[pos:], true
}