283 lines
6.8 KiB
Go
283 lines
6.8 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
|
|
"git.difuse.io/Difuse/Mellaris/io"
|
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
|
|
|
"github.com/bwmarrin/snowflake"
|
|
"github.com/google/gopacket"
|
|
"github.com/google/gopacket/layers"
|
|
)
|
|
|
|
type workerPacket struct {
|
|
Packet io.Packet
|
|
StreamID uint32
|
|
Data []byte
|
|
SrcMAC net.HardwareAddr
|
|
DstMAC net.HardwareAddr
|
|
Gen int64
|
|
}
|
|
|
|
type workerResult struct {
|
|
Packet io.Packet
|
|
StreamID uint32
|
|
Verdict io.Verdict
|
|
ModifiedPacket []byte
|
|
Gen int64
|
|
}
|
|
|
|
type worker struct {
|
|
id int
|
|
packetChan chan *workerPacket
|
|
resultChan chan workerResult
|
|
logger Logger
|
|
macResolver *sourceMACResolver
|
|
|
|
tcpFlowMgr *tcpFlowManager
|
|
udpSM *udpStreamManager
|
|
|
|
modSerializeBuffer gopacket.SerializeBuffer
|
|
}
|
|
|
|
type workerConfig struct {
|
|
ID int
|
|
ChanSize int
|
|
Logger Logger
|
|
Ruleset ruleset.Ruleset
|
|
MACResolver *sourceMACResolver
|
|
TCPMaxBufferedPagesTotal int // unused, kept for config compat
|
|
TCPMaxBufferedPagesPerConn int // unused, kept for config compat
|
|
UDPMaxStreams int
|
|
AnalyzerSelectionMode AnalyzerSelectionMode
|
|
ResultChan chan workerResult
|
|
Stats *statsCounters
|
|
}
|
|
|
|
func (c *workerConfig) fillDefaults() {
|
|
if c.ChanSize <= 0 {
|
|
c.ChanSize = 64
|
|
}
|
|
if c.UDPMaxStreams <= 0 {
|
|
c.UDPMaxStreams = 4096
|
|
}
|
|
}
|
|
|
|
func newWorker(config workerConfig) (*worker, error) {
|
|
config.fillDefaults()
|
|
sfNode, err := snowflake.NewNode(int64(config.ID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
selector := newAnalyzerSelector(config.AnalyzerSelectionMode, config.Stats)
|
|
tcpMgr := newTCPFlowManager(config.ID, config.Logger, config.MACResolver, sfNode, selector)
|
|
if config.Ruleset != nil {
|
|
tcpMgr.updateRuleset(config.Ruleset, 0)
|
|
}
|
|
|
|
udpSF := &udpStreamFactory{
|
|
WorkerID: config.ID,
|
|
Logger: config.Logger,
|
|
Node: sfNode,
|
|
Ruleset: config.Ruleset,
|
|
Selector: selector,
|
|
Stats: config.Stats,
|
|
}
|
|
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams, config.Stats)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &worker{
|
|
id: config.ID,
|
|
packetChan: make(chan *workerPacket, config.ChanSize),
|
|
resultChan: config.ResultChan,
|
|
logger: config.Logger,
|
|
macResolver: config.MACResolver,
|
|
tcpFlowMgr: tcpMgr,
|
|
udpSM: udpSM,
|
|
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
|
}, nil
|
|
}
|
|
|
|
func (w *worker) Feed(p *workerPacket) bool {
|
|
select {
|
|
case w.packetChan <- p:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (w *worker) FeedBlocking(p *workerPacket) {
|
|
w.packetChan <- p
|
|
}
|
|
|
|
func (w *worker) Run(ctx context.Context) {
|
|
w.logger.WorkerStart(w.id)
|
|
defer w.logger.WorkerStop(w.id)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case wp := <-w.packetChan:
|
|
if wp == nil {
|
|
return
|
|
}
|
|
v, b := w.handle(wp)
|
|
w.resultChan <- workerResult{
|
|
Packet: wp.Packet,
|
|
StreamID: wp.StreamID,
|
|
Verdict: v,
|
|
ModifiedPacket: b,
|
|
Gen: wp.Gen,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
|
|
w.tcpFlowMgr.updateRuleset(r, 0)
|
|
return w.udpSM.factory.UpdateRuleset(r)
|
|
}
|
|
|
|
func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
|
|
data := wp.Data
|
|
if len(data) == 0 {
|
|
return io.VerdictAccept, nil
|
|
}
|
|
|
|
if v, b, ok := w.handleIPPacket(wp, data); ok {
|
|
return v, b
|
|
}
|
|
|
|
// Ethernet frame fallback path (for custom PacketIO implementations).
|
|
if l3Payload, ok := extractL3PayloadFromEthernet(data); ok {
|
|
if v, b, ok := w.handleIPPacket(wp, l3Payload); ok {
|
|
return v, b
|
|
}
|
|
}
|
|
|
|
return io.VerdictAccept, nil
|
|
}
|
|
|
|
func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) {
|
|
l3, transport, ok := ParseL3(data)
|
|
if !ok {
|
|
return io.VerdictAccept, nil, false
|
|
}
|
|
switch l3.Protocol {
|
|
case 6: // TCP
|
|
tcp, payload, ok := ParseTCP(transport)
|
|
if !ok {
|
|
return io.VerdictAccept, nil, true
|
|
}
|
|
verdict := w.tcpFlowMgr.handle(
|
|
wp.StreamID, l3, tcp, payload,
|
|
wp.SrcMAC, wp.DstMAC,
|
|
)
|
|
return verdict, nil, true
|
|
case 17: // UDP
|
|
udp, payload, ok := ParseUDP(transport)
|
|
if !ok {
|
|
return io.VerdictAccept, nil, true
|
|
}
|
|
v, modPayload := w.handleUDP(
|
|
wp.StreamID, l3, udp, payload,
|
|
wp.SrcMAC, wp.DstMAC,
|
|
)
|
|
if v == io.VerdictAcceptModify && modPayload != nil {
|
|
mv, mb := w.serializeModifiedUDP(data, l3, modPayload)
|
|
return mv, mb, true
|
|
}
|
|
return v, nil, true
|
|
default:
|
|
return io.VerdictAccept, nil, true
|
|
}
|
|
}
|
|
|
|
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
|
ipSrc := l3.SrcIPAddr()
|
|
ipDst := l3.DstIPAddr()
|
|
endpointType := layers.EndpointIPv4
|
|
flowSrc := ipSrc.To4()
|
|
flowDst := ipDst.To4()
|
|
if l3.Version == 6 {
|
|
endpointType = layers.EndpointIPv6
|
|
flowSrc = ipSrc.To16()
|
|
flowDst = ipDst.To16()
|
|
}
|
|
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
|
|
|
|
if len(srcMAC) == 0 && w.macResolver != nil {
|
|
srcMAC = w.macResolver.Resolve(ipSrc)
|
|
}
|
|
|
|
uc := &udpContext{
|
|
Verdict: udpVerdictAccept,
|
|
SrcMAC: srcMAC,
|
|
DstMAC: dstMAC,
|
|
}
|
|
// Temporarily set payload on a UDP layer so existing UDP handling works.
|
|
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
|
BaseLayer: layers.BaseLayer{Payload: payload},
|
|
SrcPort: layers.UDPPort(udp.SrcPort),
|
|
DstPort: layers.UDPPort(udp.DstPort),
|
|
}, uc)
|
|
return io.Verdict(uc.Verdict), uc.Packet
|
|
}
|
|
|
|
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) {
|
|
layerType := layers.LayerTypeIPv4
|
|
if l3.Version == 6 {
|
|
layerType = layers.LayerTypeIPv6
|
|
}
|
|
ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
|
netLayer := ipPkt.NetworkLayer()
|
|
trLayer := ipPkt.TransportLayer()
|
|
if netLayer == nil || trLayer == nil {
|
|
return io.VerdictAccept, nil
|
|
}
|
|
udpLayer, ok := trLayer.(*layers.UDP)
|
|
if !ok {
|
|
return io.VerdictAccept, nil
|
|
}
|
|
udpLayer.Payload = modPayload
|
|
_ = udpLayer.SetNetworkLayerForChecksum(netLayer)
|
|
_ = w.modSerializeBuffer.Clear()
|
|
err := gopacket.SerializePacket(w.modSerializeBuffer,
|
|
gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}, ipPkt)
|
|
if err != nil {
|
|
return io.VerdictAccept, nil
|
|
}
|
|
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
|
|
}
|
|
|
|
// extractL3PayloadFromEthernet strips an Ethernet II header from data and
// returns the remaining bytes when they carry IPv4 or IPv6.
//
// Any number of stacked 802.1Q (0x8100) / 802.1ad (0x88A8) tags are skipped.
// It reports false for frames shorter than a bare header, truncated VLAN
// tags, non-IP EtherTypes, or an empty payload. The returned slice aliases
// data.
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
	const (
		ethHeaderLen  = 14
		vlanTagLen    = 4
		etherTypeVLAN = 0x8100
		etherTypeQinQ = 0x88A8
		etherTypeIPv4 = 0x0800
		etherTypeIPv6 = 0x86DD
	)
	if len(data) < ethHeaderLen {
		return nil, false
	}

	// readU16 decodes a big-endian uint16 at data[i:i+2].
	readU16 := func(i int) uint16 { return uint16(data[i])<<8 | uint16(data[i+1]) }

	pos := ethHeaderLen
	etherType := readU16(ethHeaderLen - 2)
	for etherType == etherTypeVLAN || etherType == etherTypeQinQ {
		if pos+vlanTagLen > len(data) {
			return nil, false
		}
		// Each tag is 2 bytes of TCI followed by the inner EtherType.
		etherType = readU16(pos + 2)
		pos += vlanTagLen
	}

	isIP := etherType == etherTypeIPv4 || etherType == etherTypeIPv6
	if !isIP || pos >= len(data) {
		return nil, false
	}
	return data[pos:], true
}
|