First Commit
This commit is contained in:
115
engine/engine.go
Normal file
115
engine/engine.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
var _ Engine = (*engine)(nil)
|
||||
|
||||
type engine struct {
|
||||
logger Logger
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
}
|
||||
|
||||
func NewEngine(config Config) (Engine, error) {
|
||||
workerCount := config.Workers
|
||||
if workerCount <= 0 {
|
||||
workerCount = runtime.NumCPU()
|
||||
}
|
||||
var err error
|
||||
workers := make([]*worker, workerCount)
|
||||
for i := range workers {
|
||||
workers[i], err = newWorker(workerConfig{
|
||||
ID: i,
|
||||
ChanSize: config.WorkerQueueSize,
|
||||
Logger: config.Logger,
|
||||
Ruleset: config.Ruleset,
|
||||
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||||
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||||
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &engine{
|
||||
logger: config.Logger,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
for _, w := range e.workers {
|
||||
if err := w.UpdateRuleset(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *engine) Run(ctx context.Context) error {
|
||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||
defer ioCancel() // Stop workers & IO
|
||||
|
||||
// Start workers
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
|
||||
// Register IO callback
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return false
|
||||
}
|
||||
return e.dispatch(p)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Block until IO errors or context is cancelled
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch dispatches a packet to a worker.
|
||||
func (e *engine) dispatch(p io.Packet) bool {
|
||||
data := p.Data()
|
||||
ipVersion := data[0] >> 4
|
||||
var layerType gopacket.LayerType
|
||||
if ipVersion == 4 {
|
||||
layerType = layers.LayerTypeIPv4
|
||||
} else if ipVersion == 6 {
|
||||
layerType = layers.LayerTypeIPv6
|
||||
} else {
|
||||
// Unsupported network layer
|
||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
return true
|
||||
}
|
||||
// Load balance by stream ID
|
||||
index := p.StreamID() % uint32(len(e.workers))
|
||||
packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
||||
e.workers[index].Feed(&workerPacket{
|
||||
StreamID: p.StreamID(),
|
||||
Packet: packet,
|
||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||
return e.io.SetVerdict(p, v, b)
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
49
engine/interface.go
Normal file
49
engine/interface.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
)
|
||||
|
||||
// Engine is the main engine for Mellaris.
|
||||
type Engine interface {
|
||||
// UpdateRuleset updates the ruleset.
|
||||
UpdateRuleset(ruleset.Ruleset) error
|
||||
// Run runs the engine, until an error occurs or the context is cancelled.
|
||||
Run(context.Context) error
|
||||
}
|
||||
|
||||
// Config is the configuration for the engine.
|
||||
type Config struct {
|
||||
Logger Logger
|
||||
IO io.PacketIO
|
||||
Ruleset ruleset.Ruleset
|
||||
|
||||
Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
|
||||
WorkerQueueSize int
|
||||
WorkerTCPMaxBufferedPagesTotal int
|
||||
WorkerTCPMaxBufferedPagesPerConn int
|
||||
WorkerUDPMaxStreams int
|
||||
}
|
||||
|
||||
// Logger is the combined logging interface for the engine, workers and analyzers.
|
||||
type Logger interface {
|
||||
WorkerStart(id int)
|
||||
WorkerStop(id int)
|
||||
|
||||
TCPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
||||
|
||||
UDPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||
UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
||||
|
||||
ModifyError(info ruleset.StreamInfo, err error)
|
||||
|
||||
AnalyzerDebugf(streamID int64, name string, format string, args ...interface{})
|
||||
AnalyzerInfof(streamID int64, name string, format string, args ...interface{})
|
||||
AnalyzerErrorf(streamID int64, name string, format string, args ...interface{})
|
||||
}
|
||||
229
engine/tcp.go
Normal file
229
engine/tcp.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/reassembly"
|
||||
)
|
||||
|
||||
// tcpVerdict is a subset of io.Verdict for TCP streams.
|
||||
// We don't allow modifying or dropping a single packet
|
||||
// for TCP streams for now, as it doesn't make much sense.
|
||||
type tcpVerdict io.Verdict
|
||||
|
||||
const (
|
||||
tcpVerdictAccept = tcpVerdict(io.VerdictAccept)
|
||||
tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream)
|
||||
tcpVerdictDropStream = tcpVerdict(io.VerdictDropStream)
|
||||
)
|
||||
|
||||
type tcpContext struct {
|
||||
*gopacket.PacketMetadata
|
||||
Verdict tcpVerdict
|
||||
}
|
||||
|
||||
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
|
||||
return ctx.CaptureInfo
|
||||
}
|
||||
|
||||
type tcpStreamFactory struct {
|
||||
WorkerID int
|
||||
Logger Logger
|
||||
Node *snowflake.Node
|
||||
|
||||
RulesetMutex sync.RWMutex
|
||||
Ruleset ruleset.Ruleset
|
||||
}
|
||||
|
||||
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
|
||||
id := f.Node.Generate()
|
||||
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
|
||||
info := ruleset.StreamInfo{
|
||||
ID: id.Int64(),
|
||||
Protocol: ruleset.ProtocolTCP,
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(tcp.SrcPort),
|
||||
DstPort: uint16(tcp.DstPort),
|
||||
Props: make(analyzer.CombinedPropMap),
|
||||
}
|
||||
f.Logger.TCPStreamNew(f.WorkerID, info)
|
||||
f.RulesetMutex.RLock()
|
||||
rs := f.Ruleset
|
||||
f.RulesetMutex.RUnlock()
|
||||
ans := analyzersToTCPAnalyzers(rs.Analyzers(info))
|
||||
// Create entries for each analyzer
|
||||
entries := make([]*tcpStreamEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
entries = append(entries, &tcpStreamEntry{
|
||||
Name: a.Name(),
|
||||
Stream: a.NewTCP(analyzer.TCPInfo{
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(tcp.SrcPort),
|
||||
DstPort: uint16(tcp.DstPort),
|
||||
}, &analyzerLogger{
|
||||
StreamID: id.Int64(),
|
||||
Name: a.Name(),
|
||||
Logger: f.Logger,
|
||||
}),
|
||||
HasLimit: a.Limit() > 0,
|
||||
Quota: a.Limit(),
|
||||
})
|
||||
}
|
||||
return &tcpStream{
|
||||
info: info,
|
||||
virgin: true,
|
||||
logger: f.Logger,
|
||||
ruleset: rs,
|
||||
activeEntries: entries,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
f.RulesetMutex.Lock()
|
||||
defer f.RulesetMutex.Unlock()
|
||||
f.Ruleset = r
|
||||
return nil
|
||||
}
|
||||
|
||||
type tcpStream struct {
|
||||
info ruleset.StreamInfo
|
||||
virgin bool // true if no packets have been processed
|
||||
logger Logger
|
||||
ruleset ruleset.Ruleset
|
||||
activeEntries []*tcpStreamEntry
|
||||
doneEntries []*tcpStreamEntry
|
||||
lastVerdict tcpVerdict
|
||||
}
|
||||
|
||||
type tcpStreamEntry struct {
|
||||
Name string
|
||||
Stream analyzer.TCPStream
|
||||
HasLimit bool
|
||||
Quota int
|
||||
}
|
||||
|
||||
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
|
||||
if len(s.activeEntries) > 0 || s.virgin {
|
||||
// Make sure every stream matches against the ruleset at least once,
|
||||
// even if there are no activeEntries, as the ruleset may have built-in
|
||||
// properties that need to be matched.
|
||||
return true
|
||||
} else {
|
||||
ctx := ac.(*tcpContext)
|
||||
ctx.Verdict = s.lastVerdict
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
|
||||
dir, start, end, skip := sg.Info()
|
||||
rev := dir == reassembly.TCPDirServerToClient
|
||||
avail, _ := sg.Lengths()
|
||||
data := sg.Fetch(avail)
|
||||
updated := false
|
||||
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
||||
// Important: reverse order so we can remove entries
|
||||
entry := s.activeEntries[i]
|
||||
update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
|
||||
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
||||
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
||||
updated = updated || up1 || up2
|
||||
if done {
|
||||
s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
|
||||
s.doneEntries = append(s.doneEntries, entry)
|
||||
}
|
||||
}
|
||||
ctx := ac.(*tcpContext)
|
||||
if updated || s.virgin {
|
||||
s.virgin = false
|
||||
s.logger.TCPStreamPropUpdate(s.info, false)
|
||||
// Match properties against ruleset
|
||||
result := s.ruleset.Match(s.info)
|
||||
action := result.Action
|
||||
if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
|
||||
verdict := actionToTCPVerdict(action)
|
||||
s.lastVerdict = verdict
|
||||
ctx.Verdict = verdict
|
||||
s.logger.TCPStreamAction(s.info, action, false)
|
||||
// Verdict issued, no need to process any more packets
|
||||
s.closeActiveEntries()
|
||||
}
|
||||
}
|
||||
if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
|
||||
// All entries are done but no verdict issued, accept stream
|
||||
s.lastVerdict = tcpVerdictAcceptStream
|
||||
ctx.Verdict = tcpVerdictAcceptStream
|
||||
s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
|
||||
s.closeActiveEntries()
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *tcpStream) closeActiveEntries() {
|
||||
// Signal close to all active entries & move them to doneEntries
|
||||
updated := false
|
||||
for _, entry := range s.activeEntries {
|
||||
update := entry.Stream.Close(false)
|
||||
up := processPropUpdate(s.info.Props, entry.Name, update)
|
||||
updated = updated || up
|
||||
}
|
||||
if updated {
|
||||
s.logger.TCPStreamPropUpdate(s.info, true)
|
||||
}
|
||||
s.doneEntries = append(s.doneEntries, s.activeEntries...)
|
||||
s.activeEntries = nil
|
||||
}
|
||||
|
||||
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
||||
if !entry.HasLimit {
|
||||
update, done = entry.Stream.Feed(rev, start, end, skip, data)
|
||||
} else {
|
||||
qData := data
|
||||
if len(qData) > entry.Quota {
|
||||
qData = qData[:entry.Quota]
|
||||
}
|
||||
update, done = entry.Stream.Feed(rev, start, end, skip, qData)
|
||||
entry.Quota -= len(qData)
|
||||
if entry.Quota <= 0 {
|
||||
// Quota exhausted, signal close & move to doneEntries
|
||||
closeUpdate = entry.Stream.Close(true)
|
||||
done = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
|
||||
tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
if tcpM, ok := a.(analyzer.TCPAnalyzer); ok {
|
||||
tcpAns = append(tcpAns, tcpM)
|
||||
}
|
||||
}
|
||||
return tcpAns
|
||||
}
|
||||
|
||||
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
|
||||
switch a {
|
||||
case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify:
|
||||
return tcpVerdictAcceptStream
|
||||
case ruleset.ActionBlock, ruleset.ActionDrop:
|
||||
return tcpVerdictDropStream
|
||||
default:
|
||||
// Should never happen
|
||||
return tcpVerdictAcceptStream
|
||||
}
|
||||
}
|
||||
299
engine/udp.go
Normal file
299
engine/udp.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/modifier"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
)
|
||||
|
||||
// udpVerdict is a subset of io.Verdict for UDP streams.
|
||||
// For UDP, we support all verdicts.
|
||||
type udpVerdict io.Verdict
|
||||
|
||||
const (
|
||||
udpVerdictAccept = udpVerdict(io.VerdictAccept)
|
||||
udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify)
|
||||
udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream)
|
||||
udpVerdictDrop = udpVerdict(io.VerdictDrop)
|
||||
udpVerdictDropStream = udpVerdict(io.VerdictDropStream)
|
||||
)
|
||||
|
||||
var errInvalidModifier = errors.New("invalid modifier")
|
||||
|
||||
type udpContext struct {
|
||||
Verdict udpVerdict
|
||||
Packet []byte
|
||||
}
|
||||
|
||||
type udpStreamFactory struct {
|
||||
WorkerID int
|
||||
Logger Logger
|
||||
Node *snowflake.Node
|
||||
|
||||
RulesetMutex sync.RWMutex
|
||||
Ruleset ruleset.Ruleset
|
||||
}
|
||||
|
||||
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
|
||||
id := f.Node.Generate()
|
||||
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
|
||||
info := ruleset.StreamInfo{
|
||||
ID: id.Int64(),
|
||||
Protocol: ruleset.ProtocolUDP,
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(udp.SrcPort),
|
||||
DstPort: uint16(udp.DstPort),
|
||||
Props: make(analyzer.CombinedPropMap),
|
||||
}
|
||||
f.Logger.UDPStreamNew(f.WorkerID, info)
|
||||
f.RulesetMutex.RLock()
|
||||
rs := f.Ruleset
|
||||
f.RulesetMutex.RUnlock()
|
||||
ans := analyzersToUDPAnalyzers(rs.Analyzers(info))
|
||||
// Create entries for each analyzer
|
||||
entries := make([]*udpStreamEntry, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
entries = append(entries, &udpStreamEntry{
|
||||
Name: a.Name(),
|
||||
Stream: a.NewUDP(analyzer.UDPInfo{
|
||||
SrcIP: ipSrc,
|
||||
DstIP: ipDst,
|
||||
SrcPort: uint16(udp.SrcPort),
|
||||
DstPort: uint16(udp.DstPort),
|
||||
}, &analyzerLogger{
|
||||
StreamID: id.Int64(),
|
||||
Name: a.Name(),
|
||||
Logger: f.Logger,
|
||||
}),
|
||||
HasLimit: a.Limit() > 0,
|
||||
Quota: a.Limit(),
|
||||
})
|
||||
}
|
||||
return &udpStream{
|
||||
info: info,
|
||||
virgin: true,
|
||||
logger: f.Logger,
|
||||
ruleset: rs,
|
||||
activeEntries: entries,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
f.RulesetMutex.Lock()
|
||||
defer f.RulesetMutex.Unlock()
|
||||
f.Ruleset = r
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpStreamManager struct {
|
||||
factory *udpStreamFactory
|
||||
streams *lru.Cache[uint32, *udpStreamValue]
|
||||
}
|
||||
|
||||
type udpStreamValue struct {
|
||||
Stream *udpStream
|
||||
IPFlow gopacket.Flow
|
||||
UDPFlow gopacket.Flow
|
||||
}
|
||||
|
||||
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
||||
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
||||
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
||||
return fwd || rev, rev
|
||||
}
|
||||
|
||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
|
||||
ss, err := lru.New[uint32, *udpStreamValue](maxStreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &udpStreamManager{
|
||||
factory: factory,
|
||||
streams: ss,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
||||
rev := false
|
||||
value, ok := m.streams.Get(streamID)
|
||||
if !ok {
|
||||
// New stream
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
||||
IPFlow: ipFlow,
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
} else {
|
||||
// Stream ID exists, but is it really the same stream?
|
||||
ok, rev = value.Match(ipFlow, udp.TransportFlow())
|
||||
if !ok {
|
||||
// It's not - close the old stream & replace it with a new one
|
||||
value.Stream.Close()
|
||||
value = &udpStreamValue{
|
||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
||||
IPFlow: ipFlow,
|
||||
UDPFlow: udp.TransportFlow(),
|
||||
}
|
||||
m.streams.Add(streamID, value)
|
||||
}
|
||||
}
|
||||
if value.Stream.Accept(udp, rev, uc) {
|
||||
value.Stream.Feed(udp, rev, uc)
|
||||
}
|
||||
}
|
||||
|
||||
type udpStream struct {
|
||||
info ruleset.StreamInfo
|
||||
virgin bool // true if no packets have been processed
|
||||
logger Logger
|
||||
ruleset ruleset.Ruleset
|
||||
activeEntries []*udpStreamEntry
|
||||
doneEntries []*udpStreamEntry
|
||||
lastVerdict udpVerdict
|
||||
}
|
||||
|
||||
type udpStreamEntry struct {
|
||||
Name string
|
||||
Stream analyzer.UDPStream
|
||||
HasLimit bool
|
||||
Quota int
|
||||
}
|
||||
|
||||
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
|
||||
if len(s.activeEntries) > 0 || s.virgin {
|
||||
// Make sure every stream matches against the ruleset at least once,
|
||||
// even if there are no activeEntries, as the ruleset may have built-in
|
||||
// properties that need to be matched.
|
||||
return true
|
||||
} else {
|
||||
uc.Verdict = s.lastVerdict
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
|
||||
updated := false
|
||||
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
||||
// Important: reverse order so we can remove entries
|
||||
entry := s.activeEntries[i]
|
||||
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
|
||||
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
||||
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
||||
updated = updated || up1 || up2
|
||||
if done {
|
||||
s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
|
||||
s.doneEntries = append(s.doneEntries, entry)
|
||||
}
|
||||
}
|
||||
if updated || s.virgin {
|
||||
s.virgin = false
|
||||
s.logger.UDPStreamPropUpdate(s.info, false)
|
||||
// Match properties against ruleset
|
||||
result := s.ruleset.Match(s.info)
|
||||
action := result.Action
|
||||
if action == ruleset.ActionModify {
|
||||
// Call the modifier instance
|
||||
udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance)
|
||||
if !ok {
|
||||
// Not for UDP, fallback to maybe
|
||||
s.logger.ModifyError(s.info, errInvalidModifier)
|
||||
action = ruleset.ActionMaybe
|
||||
} else {
|
||||
var err error
|
||||
uc.Packet, err = udpMI.Process(udp.Payload)
|
||||
if err != nil {
|
||||
// Modifier error, fallback to maybe
|
||||
s.logger.ModifyError(s.info, err)
|
||||
action = ruleset.ActionMaybe
|
||||
}
|
||||
}
|
||||
}
|
||||
if action != ruleset.ActionMaybe {
|
||||
verdict, final := actionToUDPVerdict(action)
|
||||
s.lastVerdict = verdict
|
||||
uc.Verdict = verdict
|
||||
s.logger.UDPStreamAction(s.info, action, false)
|
||||
if final {
|
||||
s.closeActiveEntries()
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept {
|
||||
// All entries are done but no verdict issued, accept stream
|
||||
s.lastVerdict = udpVerdictAcceptStream
|
||||
uc.Verdict = udpVerdictAcceptStream
|
||||
s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *udpStream) Close() {
|
||||
s.closeActiveEntries()
|
||||
}
|
||||
|
||||
func (s *udpStream) closeActiveEntries() {
|
||||
// Signal close to all active entries & move them to doneEntries
|
||||
updated := false
|
||||
for _, entry := range s.activeEntries {
|
||||
update := entry.Stream.Close(false)
|
||||
up := processPropUpdate(s.info.Props, entry.Name, update)
|
||||
updated = updated || up
|
||||
}
|
||||
if updated {
|
||||
s.logger.UDPStreamPropUpdate(s.info, true)
|
||||
}
|
||||
s.doneEntries = append(s.doneEntries, s.activeEntries...)
|
||||
s.activeEntries = nil
|
||||
}
|
||||
|
||||
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
||||
update, done = entry.Stream.Feed(rev, data)
|
||||
if entry.HasLimit {
|
||||
entry.Quota -= len(data)
|
||||
if entry.Quota <= 0 {
|
||||
// Quota exhausted, signal close & move to doneEntries
|
||||
closeUpdate = entry.Stream.Close(true)
|
||||
done = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
|
||||
udpAns := make([]analyzer.UDPAnalyzer, 0, len(ans))
|
||||
for _, a := range ans {
|
||||
if udpM, ok := a.(analyzer.UDPAnalyzer); ok {
|
||||
udpAns = append(udpAns, udpM)
|
||||
}
|
||||
}
|
||||
return udpAns
|
||||
}
|
||||
|
||||
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
|
||||
switch a {
|
||||
case ruleset.ActionMaybe:
|
||||
return udpVerdictAccept, false
|
||||
case ruleset.ActionAllow:
|
||||
return udpVerdictAcceptStream, true
|
||||
case ruleset.ActionBlock:
|
||||
return udpVerdictDropStream, true
|
||||
case ruleset.ActionDrop:
|
||||
return udpVerdictDrop, false
|
||||
case ruleset.ActionModify:
|
||||
return udpVerdictAcceptModify, false
|
||||
default:
|
||||
// Should never happen
|
||||
return udpVerdictAccept, false
|
||||
}
|
||||
}
|
||||
50
engine/utils.go
Normal file
50
engine/utils.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package engine
|
||||
|
||||
import "git.difuse.io/Difuse/Mellaris/analyzer"
|
||||
|
||||
var _ analyzer.Logger = (*analyzerLogger)(nil)
|
||||
|
||||
type analyzerLogger struct {
|
||||
StreamID int64
|
||||
Name string
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
func (l *analyzerLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Logger.AnalyzerDebugf(l.StreamID, l.Name, format, args...)
|
||||
}
|
||||
|
||||
func (l *analyzerLogger) Infof(format string, args ...interface{}) {
|
||||
l.Logger.AnalyzerInfof(l.StreamID, l.Name, format, args...)
|
||||
}
|
||||
|
||||
func (l *analyzerLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Logger.AnalyzerErrorf(l.StreamID, l.Name, format, args...)
|
||||
}
|
||||
|
||||
func processPropUpdate(cpm analyzer.CombinedPropMap, name string, update *analyzer.PropUpdate) (updated bool) {
|
||||
if update == nil || update.Type == analyzer.PropUpdateNone {
|
||||
return false
|
||||
}
|
||||
switch update.Type {
|
||||
case analyzer.PropUpdateMerge:
|
||||
m := cpm[name]
|
||||
if m == nil {
|
||||
m = make(analyzer.PropMap, len(update.M))
|
||||
cpm[name] = m
|
||||
}
|
||||
for k, v := range update.M {
|
||||
m[k] = v
|
||||
}
|
||||
return true
|
||||
case analyzer.PropUpdateReplace:
|
||||
cpm[name] = update.M
|
||||
return true
|
||||
case analyzer.PropUpdateDelete:
|
||||
delete(cpm, name)
|
||||
return true
|
||||
default:
|
||||
// Invalid update type, ignore for now
|
||||
return false
|
||||
}
|
||||
}
|
||||
185
engine/worker.go
Normal file
185
engine/worker.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.difuse.io/Difuse/Mellaris/io"
|
||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/reassembly"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultChanSize = 64
|
||||
defaultTCPMaxBufferedPagesTotal = 4096
|
||||
defaultTCPMaxBufferedPagesPerConnection = 64
|
||||
defaultUDPMaxStreams = 4096
|
||||
)
|
||||
|
||||
type workerPacket struct {
|
||||
StreamID uint32
|
||||
Packet gopacket.Packet
|
||||
SetVerdict func(io.Verdict, []byte) error
|
||||
}
|
||||
|
||||
type worker struct {
|
||||
id int
|
||||
packetChan chan *workerPacket
|
||||
logger Logger
|
||||
|
||||
tcpStreamFactory *tcpStreamFactory
|
||||
tcpStreamPool *reassembly.StreamPool
|
||||
tcpAssembler *reassembly.Assembler
|
||||
|
||||
udpStreamFactory *udpStreamFactory
|
||||
udpStreamManager *udpStreamManager
|
||||
|
||||
modSerializeBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
type workerConfig struct {
|
||||
ID int
|
||||
ChanSize int
|
||||
Logger Logger
|
||||
Ruleset ruleset.Ruleset
|
||||
TCPMaxBufferedPagesTotal int
|
||||
TCPMaxBufferedPagesPerConn int
|
||||
UDPMaxStreams int
|
||||
}
|
||||
|
||||
func (c *workerConfig) fillDefaults() {
|
||||
if c.ChanSize <= 0 {
|
||||
c.ChanSize = defaultChanSize
|
||||
}
|
||||
if c.TCPMaxBufferedPagesTotal <= 0 {
|
||||
c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal
|
||||
}
|
||||
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
||||
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
||||
}
|
||||
if c.UDPMaxStreams <= 0 {
|
||||
c.UDPMaxStreams = defaultUDPMaxStreams
|
||||
}
|
||||
}
|
||||
|
||||
func newWorker(config workerConfig) (*worker, error) {
|
||||
config.fillDefaults()
|
||||
sfNode, err := snowflake.NewNode(int64(config.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpSF := &tcpStreamFactory{
|
||||
WorkerID: config.ID,
|
||||
Logger: config.Logger,
|
||||
Node: sfNode,
|
||||
Ruleset: config.Ruleset,
|
||||
}
|
||||
tcpStreamPool := reassembly.NewStreamPool(tcpSF)
|
||||
tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
|
||||
tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
|
||||
tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
|
||||
udpSF := &udpStreamFactory{
|
||||
WorkerID: config.ID,
|
||||
Logger: config.Logger,
|
||||
Node: sfNode,
|
||||
Ruleset: config.Ruleset,
|
||||
}
|
||||
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &worker{
|
||||
id: config.ID,
|
||||
packetChan: make(chan *workerPacket, config.ChanSize),
|
||||
logger: config.Logger,
|
||||
tcpStreamFactory: tcpSF,
|
||||
tcpStreamPool: tcpStreamPool,
|
||||
tcpAssembler: tcpAssembler,
|
||||
udpStreamFactory: udpSF,
|
||||
udpStreamManager: udpSM,
|
||||
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *worker) Feed(p *workerPacket) {
|
||||
w.packetChan <- p
|
||||
}
|
||||
|
||||
func (w *worker) Run(ctx context.Context) {
|
||||
w.logger.WorkerStart(w.id)
|
||||
defer w.logger.WorkerStop(w.id)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case wPkt := <-w.packetChan:
|
||||
if wPkt == nil {
|
||||
// Closed
|
||||
return
|
||||
}
|
||||
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
|
||||
_ = wPkt.SetVerdict(v, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
if err := w.tcpStreamFactory.UpdateRuleset(r); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.udpStreamFactory.UpdateRuleset(r)
|
||||
}
|
||||
|
||||
func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) {
|
||||
netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
|
||||
if netLayer == nil || trLayer == nil {
|
||||
// Invalid packet
|
||||
return io.VerdictAccept, nil
|
||||
}
|
||||
ipFlow := netLayer.NetworkFlow()
|
||||
switch tr := trLayer.(type) {
|
||||
case *layers.TCP:
|
||||
return w.handleTCP(ipFlow, p.Metadata(), tr), nil
|
||||
case *layers.UDP:
|
||||
v, modPayload := w.handleUDP(streamID, ipFlow, tr)
|
||||
if v == io.VerdictAcceptModify && modPayload != nil {
|
||||
tr.Payload = modPayload
|
||||
_ = tr.SetNetworkLayerForChecksum(netLayer)
|
||||
_ = w.modSerializeBuffer.Clear()
|
||||
err := gopacket.SerializePacket(w.modSerializeBuffer,
|
||||
gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}, p)
|
||||
if err != nil {
|
||||
// Just accept without modification for now
|
||||
return io.VerdictAccept, nil
|
||||
}
|
||||
return v, w.modSerializeBuffer.Bytes()
|
||||
}
|
||||
return v, nil
|
||||
default:
|
||||
// Unsupported protocol
|
||||
return io.VerdictAccept, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
|
||||
ctx := &tcpContext{
|
||||
PacketMetadata: pMeta,
|
||||
Verdict: tcpVerdictAccept,
|
||||
}
|
||||
w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx)
|
||||
return io.Verdict(ctx.Verdict)
|
||||
}
|
||||
|
||||
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
|
||||
ctx := &udpContext{
|
||||
Verdict: udpVerdictAccept,
|
||||
}
|
||||
w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx)
|
||||
return io.Verdict(ctx.Verdict), ctx.Packet
|
||||
}
|
||||
Reference in New Issue
Block a user