Files
Mellaris/engine/tcp_flow.go
T

307 lines
7.7 KiB
Go

package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
)
// tcpFlowMaxBuffer caps, in bytes, how much in-order payload per direction is
// handed to the analyzers. A direction's buffer only grows, so once it has
// exceeded this size that direction is never fed to the analyzers again.
const tcpFlowMaxBuffer = 16384

// tcpFlowDirection identifies one side of a flow. Its values are used to index
// the per-direction state kept in tcpFlow (dirSeq, dirBuf, feedCalled).
type tcpFlowDirection uint8

const (
	// tcpDirC2S covers segments whose source port equals the source port of
	// the packet that created the flow; analyzers see these with rev=false.
	tcpDirC2S tcpFlowDirection = iota
	// tcpDirS2C covers all other segments; analyzers see these with rev=true.
	tcpDirS2C
)
// tcpFlow is the per-connection state: sequence tracking and payload buffers
// for both directions, the stream info matched against the ruleset, and the
// analyzer streams that are still (or were) inspecting the connection.
type tcpFlow struct {
	// streamID is the key under which tcpFlowManager stores this flow.
	streamID uint32
	// srcPort is the source port of the packet that created the flow;
	// resolveDirection classifies segments from this port as tcpDirC2S.
	srcPort uint16
	// dstPort is the destination port of the creating packet.
	// NOTE(review): not read by the code in this file — confirm it is needed.
	dstPort uint16
	// dirSeq holds, per direction, the next expected TCP sequence number
	// (sequence of the last accepted segment plus its payload length).
	dirSeq [2]uint32
	// dirBuf accumulates in-order payload per direction; it is what the
	// analyzers are fed while it stays within tcpFlowMaxBuffer.
	dirBuf [2][]byte
	// info describes the stream and carries the combined analyzer properties
	// that the ruleset is matched against.
	info ruleset.StreamInfo
	// virgin is true until the first ruleset match has been run.
	virgin bool
	logger Logger
	// rulesetVersion is the version of the ruleset last matched against;
	// feed compares it with the current version to detect ruleset changes.
	rulesetVersion uint64
	// rulesetSource returns the current ruleset and its version. A nil
	// rulesetSource is treated as "no ruleset" by currentRuleset.
	rulesetSource func() (ruleset.Ruleset, uint64)
	// activeEntries are analyzers that still receive payload.
	activeEntries []*tcpFlowEntry
	// doneEntries are analyzers that finished, ran out of quota, or were closed.
	doneEntries []*tcpFlowEntry
	// lastVerdict is the verdict returned for this flow's packets.
	lastVerdict io.Verdict
	// feedCalled records, per direction, whether an in-order data segment has
	// been accepted yet; the first one is accepted regardless of sequence.
	feedCalled [2]bool
}
// tcpFlowEntry is one analyzer attached to a flow.
type tcpFlowEntry struct {
	// Name is the analyzer's name; property updates are recorded under it.
	Name string
	// Stream is the analyzer's per-flow stream instance.
	Stream analyzer.TCPStream
	// HasLimit is true when the analyzer declared a positive byte limit.
	HasLimit bool
	// Quota is the remaining number of bytes the analyzer may still be fed
	// (only meaningful when HasLimit is set); feedFlowEntry decrements it.
	Quota int
}
// feed processes one TCP segment of the flow and returns the verdict to apply
// to it.
//
// A segment carrying RST or FIN closes every active analyzer and forces a
// ruleset match. Otherwise, a non-empty payload is accepted when it is the
// first data seen in its direction or starts exactly at the expected sequence
// number. Accepted payload is appended to that direction's buffer, and the
// analyzers are fed while the buffer stays within tcpFlowMaxBuffer. Segments
// that are not accepted (out of order, retransmitted) are neither buffered nor
// analyzed.
//
// l3 is currently unused.
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
	rs, version := f.currentRuleset()
	rulesetChanged := version != f.rulesetVersion
	// Nothing left to analyze and nothing to re-match: reuse the last verdict.
	if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
		return f.lastVerdict
	}

	if tcp.RST || tcp.FIN {
		f.closeActiveEntries()
		f.runMatch(rs, version, rulesetChanged, true)
		f.maybeFinalizeVerdict()
		return f.lastVerdict
	}

	propUpdated := false
	if len(payload) > 0 {
		dir, rev := f.resolveDirection(tcp)
		expected := f.dirSeq[dir]
		// expected == 0 is treated as "next sequence number not known yet".
		if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
			f.feedCalled[dir] = true
			// uint32 addition wraps, matching TCP sequence-number arithmetic.
			f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
			// The buffer only grows, so once it has passed tcpFlowMaxBuffer it is
			// never fed again. Appending further payload past that point would
			// only grow memory without bound for long-lived flows.
			if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
				f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
				if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
					propUpdated = f.feedAnalyzers(rev)
				}
			}
		}
	}

	f.runMatch(rs, version, rulesetChanged, propUpdated)
	f.maybeFinalizeVerdict()
	return f.lastVerdict
}
// feedAnalyzers hands the buffered payload of one direction (server-to-client
// when rev is true, client-to-server otherwise) to every active analyzer.
// Each analyzer's feed and close updates go through processPropUpdate against
// f.info.Props. Analyzers that report done move from activeEntries to
// doneEntries. It returns true if any property update was applied.
func (f *tcpFlow) feedAnalyzers(rev bool) bool {
	side := tcpDirC2S
	if rev {
		side = tcpDirS2C
	}
	data := f.dirBuf[side]

	anyChange := false
	// Iterate from the end so that removing entry i does not disturb the
	// entries still to be visited.
	for i := len(f.activeEntries) - 1; i >= 0; i-- {
		e := f.activeEntries[i]
		feedUpd, closeUpd, finished := feedFlowEntry(e, rev, data)

		// Both updates must be processed, so avoid a short-circuiting ||.
		changed := processPropUpdate(f.info.Props, e.Name, feedUpd)
		if processPropUpdate(f.info.Props, e.Name, closeUpd) {
			changed = true
		}
		if changed {
			anyChange = true
			f.logger.TCPStreamPropUpdate(f.info, false)
		}

		if !finished {
			continue
		}
		f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
		f.doneEntries = append(f.doneEntries, e)
	}
	return anyChange
}
// runMatch evaluates rs against the flow's stream info. It only does so when
// the flow has not been matched yet, the ruleset version changed, or the
// properties were updated.
//
// A nil rs counts as an ActionMaybe result. Any action other than Maybe or
// Modify is decisive: it sets lastVerdict, closes the remaining analyzers and
// logs the action. The match always records version as the flow's ruleset
// version and clears virgin.
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) {
	mustMatch := propUpdated || f.virgin || rulesetChanged
	if !mustMatch {
		return
	}
	f.virgin = false
	f.rulesetVersion = version

	var result ruleset.MatchResult
	if rs == nil {
		result = ruleset.MatchResult{Action: ruleset.ActionMaybe}
	} else {
		result = rs.Match(f.info)
	}

	switch action := result.Action; action {
	case ruleset.ActionMaybe, ruleset.ActionModify:
		// Not decisive; keep analyzing.
	default:
		f.lastVerdict = actionToTCPVerdict(action)
		f.closeActiveEntries()
		f.logger.TCPStreamAction(f.info, action, false)
	}
}
// maybeFinalizeVerdict upgrades a plain VerdictAccept to VerdictAcceptStream
// and logs an allow action once no analyzer is left active. Any other verdict,
// or a flow that still has active analyzers, is left unchanged.
func (f *tcpFlow) maybeFinalizeVerdict() {
	stillAnalyzing := len(f.activeEntries) != 0
	if stillAnalyzing || f.lastVerdict != io.VerdictAccept {
		return
	}
	f.lastVerdict = io.VerdictAcceptStream
	f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
}
// resolveDirection maps a segment to its direction index and reverse flag.
// A segment sent from the flow's original source port is client-to-server
// (rev=false); anything else is server-to-client (rev=true).
func (f *tcpFlow) resolveDirection(tcp TCPInfo) (dir uint8, rev bool) {
	rev = tcp.SrcPort != f.srcPort
	dir = uint8(tcpDirC2S)
	if rev {
		dir = uint8(tcpDirS2C)
	}
	return dir, rev
}
// currentRuleset returns the ruleset and version reported by the flow's
// source. Without a source it returns a nil ruleset together with the version
// the flow already has, so callers see no version change.
func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
	if src := f.rulesetSource; src != nil {
		return src()
	}
	return nil, f.rulesetVersion
}
// closeActiveEntries signals Close(false) to every still-active analyzer.
// Each close-time property update goes through processPropUpdate, and a
// single "final" prop-update log line is emitted if any of them changed
// something. All active entries are then moved to doneEntries.
func (f *tcpFlow) closeActiveEntries() {
	updated := false
	for _, entry := range f.activeEntries {
		update := entry.Stream.Close(false)
		// processPropUpdate is handed the live Props map, so it must run for
		// every entry. The previous `updated = updated || processPropUpdate(...)`
		// short-circuited and skipped it for all entries after the first one
		// that reported a change, dropping their updates.
		if processPropUpdate(f.info.Props, entry.Name, update) {
			updated = true
		}
	}
	if updated {
		f.logger.TCPStreamPropUpdate(f.info, true)
	}
	f.doneEntries = append(f.doneEntries, f.activeEntries...)
	f.activeEntries = nil
}
// tcpFlowManager owns the TCP flows handled by one worker and dispatches each
// packet to its flow.
type tcpFlowManager struct {
	// mu guards flows; handle holds it while looking up or creating a flow.
	mu sync.Mutex
	// flows maps a stream ID to its live flow.
	flows map[uint32]*tcpFlow
	// sfNode generates the snowflake IDs assigned to new streams.
	sfNode *snowflake.Node
	logger Logger
	// rulesetSource yields the current ruleset and version. It is installed by
	// updateRuleset and is not set by newTCPFlowManager.
	rulesetSource func() (ruleset.Ruleset, uint64)
	// workerID is reported when logging new streams.
	workerID int
	// macResolver, if non-nil, fills in a missing source MAC for new flows.
	macResolver *sourceMACResolver
	// selector narrows the ruleset's analyzers using the first payload.
	selector *analyzerSelector
}
// newTCPFlowManager returns an empty flow manager for the given worker. No
// ruleset source is installed; callers provide one via updateRuleset.
func newTCPFlowManager(workerID int, logger Logger, macResolver *sourceMACResolver, node *snowflake.Node, selector *analyzerSelector) *tcpFlowManager {
	mgr := new(tcpFlowManager)
	mgr.workerID = workerID
	mgr.logger = logger
	mgr.macResolver = macResolver
	mgr.sfNode = node
	mgr.selector = selector
	mgr.flows = make(map[uint32]*tcpFlow)
	return mgr
}
// handle routes one packet to the flow for streamID and creates the flow on
// first sight. It returns the flow's verdict for the packet.
//
// The flow is dropped from the table once it reaches a final stream verdict
// (AcceptStream or DropStream) or the packet carries RST or FIN. The flow
// itself is fed outside m.mu.
func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) io.Verdict {
	m.mu.Lock()
	flow, exists := m.flows[streamID]
	if !exists {
		flow = m.createFlow(streamID, l3, tcp, payload, srcMAC, dstMAC)
		m.flows[streamID] = flow
	}
	m.mu.Unlock()

	verdict := flow.feed(l3, tcp, payload)

	finished := tcp.RST || tcp.FIN
	switch verdict {
	case io.VerdictAcceptStream, io.VerdictDropStream:
		finished = true
	}
	if finished {
		m.mu.Lock()
		delete(m.flows, streamID)
		m.mu.Unlock()
	}
	return verdict
}
// createFlow builds the tcpFlow for a stream seen for the first time.
//
// It assigns a snowflake ID, fills in a missing source MAC via the resolver,
// logs the new stream, and instantiates one analyzer stream per TCP analyzer
// that the current ruleset requests for this stream. That analyzer list is
// narrowed by the selector using the first payload.
//
// handle calls it while holding m.mu, so m.rulesetSource is read directly.
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
	id := m.sfNode.Generate()
	ipSrc := l3.SrcIPAddr()
	ipDst := l3.DstIPAddr()
	if len(srcMAC) == 0 && m.macResolver != nil {
		srcMAC = m.macResolver.Resolve(ipSrc)
	}
	info := ruleset.StreamInfo{
		ID:       id.Int64(),
		Protocol: ruleset.ProtocolTCP,
		// Copy the MACs so the flow does not alias the caller's slices.
		SrcMAC:  append(net.HardwareAddr(nil), srcMAC...),
		DstMAC:  append(net.HardwareAddr(nil), dstMAC...),
		SrcIP:   ipSrc,
		DstIP:   ipDst,
		SrcPort: tcp.SrcPort,
		DstPort: tcp.DstPort,
		Props:   make(analyzer.CombinedPropMap),
	}
	m.logger.TCPStreamNew(m.workerID, info)

	// m.rulesetSource stays nil until updateRuleset runs, and calling it
	// unguarded would panic. Treat that case as "no ruleset", like
	// tcpFlow.currentRuleset does.
	var (
		rs      ruleset.Ruleset
		version uint64
	)
	if m.rulesetSource != nil {
		rs, version = m.rulesetSource()
	}

	var ans []analyzer.TCPAnalyzer
	if rs != nil {
		baseAns := rs.Analyzers(info)
		baseAns = m.selector.SelectTCP(baseAns, payload)
		ans = analyzersToTCPAnalyzers(baseAns)
	}
	entries := make([]*tcpFlowEntry, 0, len(ans))
	for _, a := range ans {
		limit := a.Limit()
		entries = append(entries, &tcpFlowEntry{
			Name: a.Name(),
			Stream: a.NewTCP(analyzer.TCPInfo{
				SrcIP:   ipSrc,
				DstIP:   ipDst,
				SrcPort: tcp.SrcPort,
				DstPort: tcp.DstPort,
			}, &analyzerLogger{
				StreamID: id.Int64(),
				Name:     a.Name(),
				Logger:   m.logger,
			}),
			HasLimit: limit > 0,
			Quota:    limit,
		})
	}

	flow := &tcpFlow{
		streamID: streamID,
		srcPort:  tcp.SrcPort,
		dstPort:  tcp.DstPort,
		info:     info,
		virgin:   true,
		logger:   m.logger,
		// Look the manager's source up on every call instead of copying the
		// function value now. updateRuleset replaces m.rulesetSource, and a
		// copy taken here would pin the flow to its creation-time ruleset, so
		// feed could never observe a version change. flow.feed runs outside
		// m.mu (see handle), so taking the lock here is safe.
		rulesetSource: func() (ruleset.Ruleset, uint64) {
			m.mu.Lock()
			src := m.rulesetSource
			m.mu.Unlock()
			if src == nil {
				return nil, version
			}
			return src()
		},
		rulesetVersion: version,
		activeEntries:  entries,
		lastVerdict:    io.VerdictAccept,
	}
	// Presumably the creating packet is the SYN, whose sequence number is
	// consumed before the first data byte. The first data segment in each
	// direction is accepted regardless (feedCalled in feed).
	flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
	return flow
}
// updateRuleset installs r, tagged with version, as the ruleset source for
// flows handled by m.
//
// The write happens under m.mu because createFlow reads m.rulesetSource while
// handle holds that lock. Without it, an update from another goroutine would
// race with flow creation.
func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
	src := func() (ruleset.Ruleset, uint64) {
		return r, version
	}
	m.mu.Lock()
	m.rulesetSource = src
	m.mu.Unlock()
}
// feedFlowEntry passes data to a single analyzer stream.
//
// Unlimited analyzers get data as-is. Limited analyzers get at most their
// remaining Quota bytes, and the quota is reduced by what was passed. When the
// quota reaches zero, the stream is closed with Close(true), its close update
// is returned as closeUpdate, and done is forced to true.
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	if !entry.HasLimit {
		update, done = entry.Stream.Feed(rev, true, false, 0, data)
		return update, nil, done
	}

	chunk := data
	if len(chunk) > entry.Quota {
		chunk = chunk[:entry.Quota]
	}
	update, done = entry.Stream.Feed(rev, true, false, 0, chunk)
	entry.Quota -= len(chunk)
	if entry.Quota > 0 {
		return update, nil, done
	}
	// Byte budget spent: tell the analyzer it was cut off and retire it.
	return update, entry.Stream.Close(true), true
}
// analyzersToTCPAnalyzers keeps only the analyzers that implement
// analyzer.TCPAnalyzer, preserving their order.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
	out := make([]analyzer.TCPAnalyzer, 0, len(ans))
	for i := range ans {
		ta, isTCP := ans[i].(analyzer.TCPAnalyzer)
		if !isTCP {
			continue
		}
		out = append(out, ta)
	}
	return out
}
// actionToTCPVerdict maps a ruleset action to a whole-stream verdict. Block
// and Drop stop the stream. Allow, Maybe, Modify and any unrecognized action
// all let it through.
func actionToTCPVerdict(a ruleset.Action) io.Verdict {
	if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
		return io.VerdictDropStream
	}
	return io.VerdictAcceptStream
}