fix: eliminate stale verdict poisoning, memory leaks, data races, and per-packet allocations in engine
This commit is contained in:
@@ -130,8 +130,8 @@ func (s *httpStream) parseResponseLine() utils.LSMAction {
|
|||||||
return utils.LSMActionCancel
|
return utils.LSMActionCancel
|
||||||
}
|
}
|
||||||
version := fields[0]
|
version := fields[0]
|
||||||
status, _ := strconv.Atoi(fields[1])
|
status, err := strconv.Atoi(fields[1])
|
||||||
if !strings.HasPrefix(version, "HTTP/") || status == 0 {
|
if err != nil || !strings.HasPrefix(version, "HTTP/") || status == 0 {
|
||||||
// Invalid version
|
// Invalid version
|
||||||
return utils.LSMActionCancel
|
return utils.LSMActionCancel
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-2
@@ -6,6 +6,8 @@ import (
|
|||||||
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
|
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxHandshakeLen = 65536
|
||||||
|
|
||||||
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)
|
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)
|
||||||
|
|
||||||
type TLSAnalyzer struct{}
|
type TLSAnalyzer struct{}
|
||||||
@@ -123,7 +125,7 @@ func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
|
s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
|
||||||
if s.clientHelloLen < minDataSize {
|
if s.clientHelloLen < minDataSize || s.clientHelloLen > maxHandshakeLen {
|
||||||
return utils.LSMActionCancel
|
return utils.LSMActionCancel
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +169,7 @@ func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
|
s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
|
||||||
if s.serverHelloLen < minDataSize {
|
if s.serverHelloLen < minDataSize || s.serverHelloLen > maxHandshakeLen {
|
||||||
return utils.LSMActionCancel
|
return utils.LSMActionCancel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ const (
|
|||||||
OpenVPNMinPktLen = 6
|
OpenVPNMinPktLen = 6
|
||||||
OpenVPNTCPPktDefaultLimit = 256
|
OpenVPNTCPPktDefaultLimit = 256
|
||||||
OpenVPNUDPPktDefaultLimit = 256
|
OpenVPNUDPPktDefaultLimit = 256
|
||||||
|
OpenVPNTCPMaxPktLen = 4096
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenVPNAnalyzer struct{}
|
type OpenVPNAnalyzer struct{}
|
||||||
@@ -195,7 +196,7 @@ func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
|
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
|
||||||
if len(data) == 0 {
|
if len(data) < OpenVPNMinPktLen {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
var update *analyzer.PropUpdate
|
var update *analyzer.PropUpdate
|
||||||
@@ -338,7 +339,7 @@ func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAc
|
|||||||
return nil, utils.LSMActionPause
|
return nil, utils.LSMActionPause
|
||||||
}
|
}
|
||||||
|
|
||||||
if pktLen < OpenVPNMinPktLen {
|
if pktLen < OpenVPNMinPktLen || pktLen > OpenVPNTCPMaxPktLen {
|
||||||
return nil, utils.LSMActionCancel
|
return nil, utils.LSMActionCancel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
quicInvalidCountThreshold = 16
|
quicInvalidCountThreshold = 16
|
||||||
quicMaxCryptoDataLen = 256 * 1024
|
quicMaxCryptoDataLen = 256 * 1024
|
||||||
|
quicMaxFrameEntries = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -158,6 +159,9 @@ func (s *quicStream) mergeFrame(offset int64, data []byte) {
|
|||||||
if len(data) == 0 || offset < 0 {
|
if len(data) == 0 || offset < 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(s.frames) >= quicMaxFrameEntries {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.frames == nil {
|
if s.frames == nil {
|
||||||
s.frames = make(map[int64][]byte)
|
s.frames = make(map[int64][]byte)
|
||||||
}
|
}
|
||||||
|
|||||||
+33
-4
@@ -5,6 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
@@ -15,8 +16,14 @@ var _ Engine = (*engine)(nil)
|
|||||||
type verdictEntry struct {
|
type verdictEntry struct {
|
||||||
Verdict io.Verdict
|
Verdict io.Verdict
|
||||||
Gen int64
|
Gen int64
|
||||||
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
verdictTTL = 15 * time.Second
|
||||||
|
verdictSweepInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type engine struct {
|
type engine struct {
|
||||||
logger Logger
|
logger Logger
|
||||||
io io.PacketIO
|
io io.PacketIO
|
||||||
@@ -39,7 +46,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
}
|
}
|
||||||
overflowPolicy := config.OverflowPolicy
|
overflowPolicy := config.OverflowPolicy
|
||||||
if overflowPolicy == "" {
|
if overflowPolicy == "" {
|
||||||
overflowPolicy = OverflowPolicyAccept
|
overflowPolicy = OverflowPolicyDrop
|
||||||
}
|
}
|
||||||
selectionMode := config.AnalyzerSelectionMode
|
selectionMode := config.AnalyzerSelectionMode
|
||||||
if selectionMode == "" {
|
if selectionMode == "" {
|
||||||
@@ -83,7 +90,6 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
|
|
||||||
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||||
e.verdictsGen.Add(1)
|
e.verdictsGen.Add(1)
|
||||||
e.verdicts = sync.Map{}
|
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
if err := w.UpdateRuleset(r); err != nil {
|
if err := w.UpdateRuleset(r); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -100,6 +106,7 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
go e.drainResults(ioCtx)
|
go e.drainResults(ioCtx)
|
||||||
|
go e.sweepVerdicts(ioCtx)
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||||
@@ -124,6 +131,7 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
func (e *engine) dispatch(p io.Packet) bool {
|
func (e *engine) dispatch(p io.Packet) bool {
|
||||||
streamID := p.StreamID()
|
streamID := p.StreamID()
|
||||||
|
|
||||||
|
if streamID != 0 {
|
||||||
if v, ok := e.verdicts.Load(streamID); ok {
|
if v, ok := e.verdicts.Load(streamID); ok {
|
||||||
entry := v.(verdictEntry)
|
entry := v.(verdictEntry)
|
||||||
if entry.Gen == e.verdictsGen.Load() {
|
if entry.Gen == e.verdictsGen.Load() {
|
||||||
@@ -131,6 +139,7 @@ func (e *engine) dispatch(p io.Packet) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data := p.Data()
|
data := p.Data()
|
||||||
if !validPacket(data) {
|
if !validPacket(data) {
|
||||||
@@ -163,12 +172,32 @@ func (e *engine) dispatch(p io.Packet) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *engine) applyWorkerResult(r workerResult) {
|
func (e *engine) applyWorkerResult(r workerResult) {
|
||||||
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream {
|
if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
|
||||||
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen})
|
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
|
||||||
}
|
}
|
||||||
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *engine) sweepVerdicts(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(verdictSweepInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
e.verdicts.Range(func(key, value interface{}) bool {
|
||||||
|
entry := value.(verdictEntry)
|
||||||
|
if now.Sub(entry.CreatedAt) > verdictTTL {
|
||||||
|
e.verdicts.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func validPacket(data []byte) bool {
|
func validPacket(data []byte) bool {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return false
|
return false
|
||||||
|
|||||||
+4
-1
@@ -59,7 +59,10 @@ func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
|
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
|
||||||
if totalLen < int(ihl)*4 || totalLen > len(data) {
|
if totalLen < int(ihl)*4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if totalLen > len(data) {
|
||||||
totalLen = len(data)
|
totalLen = len(data)
|
||||||
}
|
}
|
||||||
return L3Info{
|
return L3Info{
|
||||||
|
|||||||
+13
-28
@@ -1,7 +1,6 @@
|
|||||||
package engine
|
package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
@@ -9,8 +8,6 @@ import (
|
|||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type fixedRuleset struct {
|
type fixedRuleset struct {
|
||||||
@@ -60,25 +57,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
|
|||||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
||||||
}
|
}
|
||||||
|
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
|
||||||
udp := &layers.UDP{
|
payload := []byte("query")
|
||||||
SrcPort: 12345,
|
|
||||||
DstPort: 53,
|
|
||||||
BaseLayer: layers.BaseLayer{
|
|
||||||
Payload: []byte("query"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := &udpContext{Verdict: udpVerdictAccept}
|
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||||
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx)
|
s := f.New(tuple, payload, ctx)
|
||||||
|
|
||||||
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
|
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
|
||||||
t.Fatalf("update ruleset: %v", err)
|
t.Fatalf("update ruleset: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.Accept(udp, false, ctx) {
|
if !s.Accept(false, ctx) {
|
||||||
t.Fatalf("unexpected Accept=false for virgin stream")
|
t.Fatalf("unexpected Accept=false for virgin stream")
|
||||||
}
|
}
|
||||||
s.Feed(udp, false, ctx)
|
s.Feed(false, payload, ctx)
|
||||||
if ctx.Verdict != udpVerdictDropStream {
|
if ctx.Verdict != udpVerdictDropStream {
|
||||||
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
|
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
|
||||||
}
|
}
|
||||||
@@ -96,21 +87,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
|||||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
||||||
}
|
}
|
||||||
|
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
|
||||||
udp := &layers.UDP{
|
payload := []byte("query")
|
||||||
SrcPort: 12345,
|
|
||||||
DstPort: 53,
|
|
||||||
BaseLayer: layers.BaseLayer{
|
|
||||||
Payload: []byte("query"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1)
|
s := f.New(tuple, payload, ctx1)
|
||||||
if !s.Accept(udp, false, ctx1) {
|
if !s.Accept(false, ctx1) {
|
||||||
t.Fatalf("unexpected Accept=false before first feed")
|
t.Fatalf("unexpected Accept=false before first feed")
|
||||||
}
|
}
|
||||||
s.Feed(udp, false, ctx1)
|
s.Feed(false, payload, ctx1)
|
||||||
if ctx1.Verdict != udpVerdictAcceptStream {
|
if ctx1.Verdict != udpVerdictAcceptStream {
|
||||||
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
|
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
|
||||||
}
|
}
|
||||||
@@ -120,16 +105,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
if !s.Accept(udp, false, ctx2) {
|
if !s.Accept(false, ctx2) {
|
||||||
t.Fatalf("expected Accept=true after ruleset update")
|
t.Fatalf("expected Accept=true after ruleset update")
|
||||||
}
|
}
|
||||||
s.Feed(udp, false, ctx2)
|
s.Feed(false, payload, ctx2)
|
||||||
if ctx2.Verdict != udpVerdictDropStream {
|
if ctx2.Verdict != udpVerdictDropStream {
|
||||||
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
|
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx3 := &udpContext{Verdict: udpVerdictAccept}
|
ctx3 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
if s.Accept(udp, false, ctx3) {
|
if s.Accept(false, ctx3) {
|
||||||
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
|
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
|
||||||
}
|
}
|
||||||
if ctx3.Verdict != udpVerdictDropStream {
|
if ctx3.Verdict != udpVerdictDropStream {
|
||||||
|
|||||||
+24
-3
@@ -3,6 +3,7 @@ package engine
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
@@ -13,6 +14,8 @@ import (
|
|||||||
|
|
||||||
const tcpFlowMaxBuffer = 16384
|
const tcpFlowMaxBuffer = 16384
|
||||||
|
|
||||||
|
const tcpFlowIdleTimeout = 10 * time.Minute
|
||||||
|
|
||||||
type tcpFlowDirection uint8
|
type tcpFlowDirection uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -37,6 +40,7 @@ type tcpFlow struct {
|
|||||||
doneEntries []*tcpFlowEntry
|
doneEntries []*tcpFlowEntry
|
||||||
lastVerdict io.Verdict
|
lastVerdict io.Verdict
|
||||||
feedCalled [2]bool
|
feedCalled [2]bool
|
||||||
|
lastSeen time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpFlowEntry struct {
|
type tcpFlowEntry struct {
|
||||||
@@ -67,16 +71,17 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
|||||||
expected := f.dirSeq[dir]
|
expected := f.dirSeq[dir]
|
||||||
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
||||||
f.feedCalled[dir] = true
|
f.feedCalled[dir] = true
|
||||||
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
|
||||||
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
|
||||||
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
|
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
|
||||||
|
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
||||||
propUpdated = f.feedAnalyzers(rev)
|
propUpdated = f.feedAnalyzers(rev)
|
||||||
}
|
}
|
||||||
|
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.runMatch(rs, version, rulesetChanged, propUpdated)
|
f.runMatch(rs, version, rulesetChanged, propUpdated)
|
||||||
f.maybeFinalizeVerdict()
|
f.maybeFinalizeVerdict()
|
||||||
|
f.lastSeen = time.Now()
|
||||||
return f.lastVerdict
|
return f.lastVerdict
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,7 +223,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
|||||||
Props: make(analyzer.CombinedPropMap),
|
Props: make(analyzer.CombinedPropMap),
|
||||||
}
|
}
|
||||||
m.logger.TCPStreamNew(m.workerID, info)
|
m.logger.TCPStreamNew(m.workerID, info)
|
||||||
rs, version := m.rulesetSource()
|
var rs ruleset.Ruleset
|
||||||
|
var version uint64
|
||||||
|
if m.rulesetSource != nil {
|
||||||
|
rs, version = m.rulesetSource()
|
||||||
|
}
|
||||||
var ans []analyzer.TCPAnalyzer
|
var ans []analyzer.TCPAnalyzer
|
||||||
if rs != nil {
|
if rs != nil {
|
||||||
baseAns := rs.Analyzers(info)
|
baseAns := rs.Analyzers(info)
|
||||||
@@ -255,6 +264,7 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
|
|||||||
rulesetVersion: version,
|
rulesetVersion: version,
|
||||||
activeEntries: entries,
|
activeEntries: entries,
|
||||||
lastVerdict: io.VerdictAccept,
|
lastVerdict: io.VerdictAccept,
|
||||||
|
lastSeen: time.Now(),
|
||||||
}
|
}
|
||||||
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
|
||||||
return flow
|
return flow
|
||||||
@@ -266,6 +276,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *tcpFlowManager) cleanupIdle(now time.Time) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
for id, flow := range m.flows {
|
||||||
|
if now.Sub(flow.lastSeen) > tcpFlowIdleTimeout {
|
||||||
|
flow.closeActiveEntries()
|
||||||
|
delete(m.flows, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
|
||||||
if !entry.HasLimit {
|
if !entry.HasLimit {
|
||||||
update, done = entry.Stream.Feed(rev, true, false, 0, data)
|
update, done = entry.Stream.Feed(rev, true, false, 0, data)
|
||||||
|
|||||||
+49
-53
@@ -12,8 +12,6 @@ import (
|
|||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
lru "github.com/hashicorp/golang-lru/v2"
|
lru "github.com/hashicorp/golang-lru/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,9 +47,10 @@ type udpStreamFactory struct {
|
|||||||
RulesetVersion uint64
|
RulesetVersion uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
|
func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
|
||||||
id := f.Node.Generate()
|
id := f.Node.Generate()
|
||||||
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw())
|
ipSrc := net.IP(k.AIP[:k.ALen])
|
||||||
|
ipDst := net.IP(k.BIP[:k.BLen])
|
||||||
info := ruleset.StreamInfo{
|
info := ruleset.StreamInfo{
|
||||||
ID: id.Int64(),
|
ID: id.Int64(),
|
||||||
Protocol: ruleset.ProtocolUDP,
|
Protocol: ruleset.ProtocolUDP,
|
||||||
@@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
|||||||
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
|
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
|
||||||
SrcIP: ipSrc,
|
SrcIP: ipSrc,
|
||||||
DstIP: ipDst,
|
DstIP: ipDst,
|
||||||
SrcPort: uint16(udp.SrcPort),
|
SrcPort: k.APort,
|
||||||
DstPort: uint16(udp.DstPort),
|
DstPort: k.BPort,
|
||||||
Props: make(analyzer.CombinedPropMap),
|
Props: make(analyzer.CombinedPropMap),
|
||||||
}
|
}
|
||||||
f.Logger.UDPStreamNew(f.WorkerID, info)
|
f.Logger.UDPStreamNew(f.WorkerID, info)
|
||||||
@@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
|||||||
if rs != nil {
|
if rs != nil {
|
||||||
baseAns := rs.Analyzers(info)
|
baseAns := rs.Analyzers(info)
|
||||||
if f.Selector != nil {
|
if f.Selector != nil {
|
||||||
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload)
|
baseAns = f.Selector.SelectUDP(baseAns, payload)
|
||||||
}
|
}
|
||||||
ans = analyzersToUDPAnalyzers(baseAns)
|
ans = analyzersToUDPAnalyzers(baseAns)
|
||||||
}
|
}
|
||||||
// Create entries for each analyzer
|
|
||||||
entries := make([]*udpStreamEntry, 0, len(ans))
|
entries := make([]*udpStreamEntry, 0, len(ans))
|
||||||
for _, a := range ans {
|
for _, a := range ans {
|
||||||
entries = append(entries, &udpStreamEntry{
|
entries = append(entries, &udpStreamEntry{
|
||||||
@@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
|
|||||||
Stream: a.NewUDP(analyzer.UDPInfo{
|
Stream: a.NewUDP(analyzer.UDPInfo{
|
||||||
SrcIP: ipSrc,
|
SrcIP: ipSrc,
|
||||||
DstIP: ipDst,
|
DstIP: ipDst,
|
||||||
SrcPort: uint16(udp.SrcPort),
|
SrcPort: k.APort,
|
||||||
DstPort: uint16(udp.DstPort),
|
DstPort: k.BPort,
|
||||||
}, &analyzerLogger{
|
}, &analyzerLogger{
|
||||||
StreamID: id.Int64(),
|
StreamID: id.Int64(),
|
||||||
Name: a.Name(),
|
Name: a.Name(),
|
||||||
@@ -126,8 +124,13 @@ type udpStreamManager struct {
|
|||||||
|
|
||||||
type udpStreamValue struct {
|
type udpStreamValue struct {
|
||||||
Stream *udpStream
|
Stream *udpStream
|
||||||
IPFlow gopacket.Flow
|
Tuple udpTupleKey
|
||||||
UDPFlow gopacket.Flow
|
}
|
||||||
|
|
||||||
|
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
|
||||||
|
fwd := v.Tuple == k
|
||||||
|
rev = v.Tuple == reverseTuple(k)
|
||||||
|
return fwd || rev, rev
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpTupleKey struct {
|
type udpTupleKey struct {
|
||||||
@@ -139,12 +142,6 @@ type udpTupleKey struct {
|
|||||||
BPort uint16
|
BPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
|
|
||||||
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
|
|
||||||
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
|
|
||||||
return fwd || rev, rev
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
|
||||||
m := &udpStreamManager{
|
m := &udpStreamManager{
|
||||||
factory: factory,
|
factory: factory,
|
||||||
@@ -153,6 +150,9 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
|
|||||||
stats: stats,
|
stats: stats,
|
||||||
}
|
}
|
||||||
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
|
||||||
|
if v != nil && v.Stream != nil {
|
||||||
|
v.Stream.Close()
|
||||||
|
}
|
||||||
m.removeTupleMappingLocked(k)
|
m.removeTupleMappingLocked(k)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -162,16 +162,12 @@ func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *stats
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
|
func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
|
||||||
rev := false
|
|
||||||
value, ok := m.streams.Get(streamID)
|
value, ok := m.streams.Get(streamID)
|
||||||
tuple := canonicalUDPTupleKey(ipFlow, udp)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
if m.stats != nil {
|
if m.stats != nil {
|
||||||
m.stats.UDPTupleLookups.Add(1)
|
m.stats.UDPTupleLookups.Add(1)
|
||||||
}
|
}
|
||||||
// Conntrack IDs can change during early flow lifetime on some systems.
|
|
||||||
// Rebind by canonical 5-tuple in O(1).
|
|
||||||
matchedKey, found := m.tupleIndex[tuple]
|
matchedKey, found := m.tupleIndex[tuple]
|
||||||
var matchedValue *udpStreamValue
|
var matchedValue *udpStreamValue
|
||||||
var matchedRev bool
|
var matchedRev bool
|
||||||
@@ -188,7 +184,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if found {
|
if found {
|
||||||
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow())
|
_, matchedRev = matchedValue.Match(tuple)
|
||||||
value = matchedValue
|
value = matchedValue
|
||||||
rev = matchedRev
|
rev = matchedRev
|
||||||
if matchedKey != streamID {
|
if matchedKey != streamID {
|
||||||
@@ -197,32 +193,27 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
|
|||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// New stream
|
|
||||||
value = &udpStreamValue{
|
value = &udpStreamValue{
|
||||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
Stream: m.factory.New(tuple, payload, uc),
|
||||||
IPFlow: ipFlow,
|
Tuple: tuple,
|
||||||
UDPFlow: udp.TransportFlow(),
|
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.streams.Add(streamID, value)
|
||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Stream ID exists, but is it really the same stream?
|
ok, rev = value.Match(tuple)
|
||||||
ok, rev = value.Match(ipFlow, udp.TransportFlow())
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// It's not - close the old stream & replace it with a new one
|
|
||||||
value.Stream.Close()
|
value.Stream.Close()
|
||||||
value = &udpStreamValue{
|
value = &udpStreamValue{
|
||||||
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
|
Stream: m.factory.New(tuple, payload, uc),
|
||||||
IPFlow: ipFlow,
|
Tuple: tuple,
|
||||||
UDPFlow: udp.TransportFlow(),
|
|
||||||
}
|
}
|
||||||
m.streams.Add(streamID, value)
|
m.streams.Add(streamID, value)
|
||||||
m.bindTupleLocked(streamID, tuple)
|
m.bindTupleLocked(streamID, tuple)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if value.Stream.Accept(udp, rev, uc) {
|
if value.Stream.Accept(rev, uc) {
|
||||||
value.Stream.Feed(udp, rev, uc)
|
value.Stream.Feed(rev, payload, uc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,25 +233,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey {
|
func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
|
||||||
srcIP := ipFlow.Src().Raw()
|
srcRaw := []byte(srcIP)
|
||||||
dstIP := ipFlow.Dst().Raw()
|
dstRaw := []byte(dstIP)
|
||||||
srcPort := uint16(udp.SrcPort)
|
|
||||||
dstPort := uint16(udp.DstPort)
|
|
||||||
|
|
||||||
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 {
|
if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
|
||||||
srcIP, dstIP = dstIP, srcIP
|
srcRaw, dstRaw = dstRaw, srcRaw
|
||||||
srcPort, dstPort = dstPort, srcPort
|
srcPort, dstPort = dstPort, srcPort
|
||||||
}
|
}
|
||||||
|
|
||||||
var key udpTupleKey
|
var key udpTupleKey
|
||||||
key.ALen = uint8(copy(key.AIP[:], srcIP))
|
key.ALen = uint8(copy(key.AIP[:], srcRaw))
|
||||||
key.BLen = uint8(copy(key.BIP[:], dstIP))
|
key.BLen = uint8(copy(key.BIP[:], dstRaw))
|
||||||
key.APort = srcPort
|
key.APort = srcPort
|
||||||
key.BPort = dstPort
|
key.BPort = dstPort
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseTuple(k udpTupleKey) udpTupleKey {
|
||||||
|
var r udpTupleKey
|
||||||
|
r.ALen = k.BLen
|
||||||
|
r.BLen = k.ALen
|
||||||
|
r.AIP = k.BIP
|
||||||
|
r.BIP = k.AIP
|
||||||
|
r.APort = k.BPort
|
||||||
|
r.BPort = k.APort
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
|
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
|
||||||
if len(aIP) != len(bIP) {
|
if len(aIP) != len(bIP) {
|
||||||
if len(aIP) < len(bIP) {
|
if len(aIP) < len(bIP) {
|
||||||
@@ -298,11 +298,8 @@ type udpStreamEntry struct {
|
|||||||
Quota int
|
Quota int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
|
func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
|
||||||
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
|
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
|
||||||
// Make sure every stream matches against the ruleset at least once,
|
|
||||||
// even if there are no activeEntries, as the ruleset may have built-in
|
|
||||||
// properties that need to be matched.
|
|
||||||
return true
|
return true
|
||||||
} else {
|
} else {
|
||||||
uc.Verdict = s.lastVerdict
|
uc.Verdict = s.lastVerdict
|
||||||
@@ -310,12 +307,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
|
func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
|
||||||
updated := false
|
updated := false
|
||||||
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
for i := len(s.activeEntries) - 1; i >= 0; i-- {
|
||||||
// Important: reverse order so we can remove entries
|
|
||||||
entry := s.activeEntries[i]
|
entry := s.activeEntries[i]
|
||||||
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
|
update, closeUpdate, done := s.feedEntry(entry, rev, payload)
|
||||||
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
up1 := processPropUpdate(s.info.Props, entry.Name, update)
|
||||||
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
|
||||||
updated = updated || up1 || up2
|
updated = updated || up1 || up2
|
||||||
@@ -345,7 +341,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
|
|||||||
action = ruleset.ActionMaybe
|
action = ruleset.ActionMaybe
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
uc.Packet, err = udpMI.Process(udp.Payload)
|
uc.Packet, err = udpMI.Process(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Modifier error, fallback to maybe
|
// Modifier error, fallback to maybe
|
||||||
s.logger.ModifyError(s.info, err)
|
s.logger.ModifyError(s.info, err)
|
||||||
|
|||||||
@@ -1,20 +1,16 @@
|
|||||||
package engine
|
package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type legacyUDPStreamValue struct {
|
type legacyUDPStreamValue struct {
|
||||||
IPFlow gopacket.Flow
|
Tuple udpTupleKey
|
||||||
UDPFlow gopacket.Flow
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type emptyRuleset struct{}
|
type emptyRuleset struct{}
|
||||||
@@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const flowCount = 20000
|
const flowCount = 20000
|
||||||
flows := make([]gopacket.Flow, flowCount)
|
tuples := make([]udpTupleKey, flowCount)
|
||||||
udps := make([]*layers.UDP, flowCount)
|
payloads := make([][]byte, flowCount)
|
||||||
for i := 0; i < flowCount; i++ {
|
for i := 0; i < flowCount; i++ {
|
||||||
a := byte(i >> 8)
|
a := byte(i >> 8)
|
||||||
c := byte(i)
|
c := byte(i)
|
||||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
var t udpTupleKey
|
||||||
udps[i] = &layers.UDP{
|
t.AIP = [16]byte{10, a, 0, c}
|
||||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
t.ALen = 4
|
||||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
t.BIP = [16]byte{172, 16, a, c}
|
||||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
t.BLen = 4
|
||||||
}
|
t.APort = 1024 + uint16(i%20000)
|
||||||
|
t.BPort = 20000 + uint16((i*7)%20000)
|
||||||
|
tuples[i] = t
|
||||||
|
payloads[i] = []byte{0x01, 0x00, 0x00, 0x00}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := &udpContext{Verdict: udpVerdictAccept}
|
ctx := &udpContext{Verdict: udpVerdictAccept}
|
||||||
@@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
|
|||||||
}
|
}
|
||||||
ctx.Verdict = udpVerdictAccept
|
ctx.Verdict = udpVerdictAccept
|
||||||
ctx.Packet = nil
|
ctx.Packet = nil
|
||||||
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx)
|
mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
|
|||||||
|
|
||||||
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
||||||
const flowCount = 5000
|
const flowCount = 5000
|
||||||
flows := make([]gopacket.Flow, flowCount)
|
tuples := make([]udpTupleKey, flowCount)
|
||||||
udps := make([]*layers.UDP, flowCount)
|
|
||||||
for i := 0; i < flowCount; i++ {
|
for i := 0; i < flowCount; i++ {
|
||||||
a := byte(i >> 8)
|
a := byte(i >> 8)
|
||||||
c := byte(i)
|
c := byte(i)
|
||||||
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4())
|
var t udpTupleKey
|
||||||
udps[i] = &layers.UDP{
|
t.AIP = [16]byte{10, a, 0, c}
|
||||||
SrcPort: layers.UDPPort(1024 + i%20000),
|
t.ALen = 4
|
||||||
DstPort: layers.UDPPort(20000 + (i*7)%20000),
|
t.BIP = [16]byte{172, 16, a, c}
|
||||||
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}},
|
t.BLen = 4
|
||||||
}
|
t.APort = 1024 + uint16(i%20000)
|
||||||
|
t.BPort = 20000 + uint16((i*7)%20000)
|
||||||
|
tuples[i] = t
|
||||||
}
|
}
|
||||||
|
|
||||||
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
|
streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
|
||||||
keys := make([]uint32, 0, flowCount)
|
keys := make([]uint32, 0, flowCount)
|
||||||
for i := 0; i < flowCount; i++ {
|
for i := 0; i < flowCount; i++ {
|
||||||
streamID := uint32(i + 1)
|
streamID := uint32(i + 1)
|
||||||
streams[streamID] = &legacyUDPStreamValue{
|
streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]}
|
||||||
IPFlow: flows[i],
|
|
||||||
UDPFlow: udps[i].TransportFlow(),
|
|
||||||
}
|
|
||||||
keys = append(keys, streamID)
|
keys = append(keys, streamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
|
|||||||
if _, ok := streams[streamID]; ok {
|
if _, ok := streams[streamID]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ipFlow := flows[idx]
|
tuple := tuples[idx]
|
||||||
udpFlow := udps[idx].TransportFlow()
|
revTuple := reverseTuple(tuple)
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
v, ok := streams[k]
|
v, ok := streams[k]
|
||||||
if !ok || v == nil {
|
if !ok || v == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) ||
|
if v.Tuple == tuple || v.Tuple == revTuple {
|
||||||
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
|
|
||||||
delete(streams, k)
|
delete(streams, k)
|
||||||
streams[streamID] = v
|
streams[streamID] = v
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package engine
|
package engine
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -9,8 +8,6 @@ import (
|
|||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type countingRuleset struct {
|
type countingRuleset struct {
|
||||||
@@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
|
|||||||
t.Fatalf("new manager: %v", err)
|
t.Fatalf("new manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443}
|
||||||
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}}
|
payload := []byte{0x01, 0x00, 0x00, 0x00}
|
||||||
|
|
||||||
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
ctx1 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
mgr.MatchWithContext(100, ipFlow, udp, ctx1)
|
mgr.MatchWithContext(100, tuple, false, payload, ctx1)
|
||||||
if got := newCalls.Load(); got != 1 {
|
if got := newCalls.Load(); got != 1 {
|
||||||
t.Fatalf("new stream calls=%d want=1", got)
|
t.Fatalf("new stream calls=%d want=1", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
ctx2 := &udpContext{Verdict: udpVerdictAccept}
|
||||||
mgr.MatchWithContext(200, ipFlow, udp, ctx2)
|
mgr.MatchWithContext(200, tuple, false, payload, ctx2)
|
||||||
if got := newCalls.Load(); got != 1 {
|
if got := newCalls.Load(); got != 1 {
|
||||||
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
|
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
|
||||||
}
|
}
|
||||||
|
|||||||
+11
-16
@@ -3,6 +3,7 @@ package engine
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/io"
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
@@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) {
|
|||||||
func (w *worker) Run(ctx context.Context) {
|
func (w *worker) Run(ctx context.Context) {
|
||||||
w.logger.WorkerStart(w.id)
|
w.logger.WorkerStart(w.id)
|
||||||
defer w.logger.WorkerStop(w.id)
|
defer w.logger.WorkerStop(w.id)
|
||||||
|
|
||||||
|
tcpSweepTicker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer tcpSweepTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
case <-tcpSweepTicker.C:
|
||||||
|
w.tcpFlowMgr.cleanupIdle(time.Now())
|
||||||
case wp := <-w.packetChan:
|
case wp := <-w.packetChan:
|
||||||
if wp == nil {
|
if wp == nil {
|
||||||
return
|
return
|
||||||
@@ -202,15 +209,6 @@ func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []by
|
|||||||
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
|
||||||
ipSrc := l3.SrcIPAddr()
|
ipSrc := l3.SrcIPAddr()
|
||||||
ipDst := l3.DstIPAddr()
|
ipDst := l3.DstIPAddr()
|
||||||
endpointType := layers.EndpointIPv4
|
|
||||||
flowSrc := ipSrc.To4()
|
|
||||||
flowDst := ipDst.To4()
|
|
||||||
if l3.Version == 6 {
|
|
||||||
endpointType = layers.EndpointIPv6
|
|
||||||
flowSrc = ipSrc.To16()
|
|
||||||
flowDst = ipDst.To16()
|
|
||||||
}
|
|
||||||
ipFlow := gopacket.NewFlow(endpointType, flowSrc, flowDst)
|
|
||||||
|
|
||||||
if len(srcMAC) == 0 && w.macResolver != nil {
|
if len(srcMAC) == 0 && w.macResolver != nil {
|
||||||
srcMAC = w.macResolver.Resolve(ipSrc)
|
srcMAC = w.macResolver.Resolve(ipSrc)
|
||||||
@@ -221,12 +219,9 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
|
|||||||
SrcMAC: srcMAC,
|
SrcMAC: srcMAC,
|
||||||
DstMAC: dstMAC,
|
DstMAC: dstMAC,
|
||||||
}
|
}
|
||||||
// Temporarily set payload on a UDP layer so existing UDP handling works.
|
|
||||||
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{
|
tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort)
|
||||||
BaseLayer: layers.BaseLayer{Payload: payload},
|
w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc)
|
||||||
SrcPort: layers.UDPPort(udp.SrcPort),
|
|
||||||
DstPort: layers.UDPPort(udp.DstPort),
|
|
||||||
}, uc)
|
|
||||||
return io.Verdict(uc.Verdict), uc.Packet
|
return io.Verdict(uc.Verdict), uc.Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,7 +248,7 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []b
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return io.VerdictAccept, nil
|
return io.VerdictAccept, nil
|
||||||
}
|
}
|
||||||
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes()
|
return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
|
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
|
||||||
|
|||||||
+1
-1
@@ -456,7 +456,7 @@ func ctIDFromCtBytes(ct []byte) uint32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
for _, attr := range ctAttrs {
|
for _, attr := range ctAttrs {
|
||||||
if attr.Type == 12 { // CTA_ID
|
if attr.Type == 12 && len(attr.Data) >= 4 { // CTA_ID
|
||||||
return binary.BigEndian.Uint32(attr.Data)
|
return binary.BigEndian.Uint32(attr.Data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
|
||||||
@@ -31,6 +32,7 @@ type V2GeoLoader struct {
|
|||||||
DownloadFunc func(filename, url string)
|
DownloadFunc func(filename, url string)
|
||||||
DownloadErrFunc func(err error)
|
DownloadErrFunc func(err error)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
geoipMap map[string]*v2geo.GeoIP
|
geoipMap map[string]*v2geo.GeoIP
|
||||||
geositeMap map[string]*v2geo.GeoSite
|
geositeMap map[string]*v2geo.GeoSite
|
||||||
}
|
}
|
||||||
@@ -80,6 +82,8 @@ func (l *V2GeoLoader) download(filename, url string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
if l.geoipMap != nil {
|
if l.geoipMap != nil {
|
||||||
return l.geoipMap, nil
|
return l.geoipMap, nil
|
||||||
}
|
}
|
||||||
@@ -104,6 +108,8 @@ func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
if l.geositeMap != nil {
|
if l.geositeMap != nil {
|
||||||
return l.geositeMap, nil
|
return l.geositeMap, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+30
-6
@@ -519,7 +519,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
InitFunc: geoMatcher.LoadGeoIP,
|
InitFunc: geoMatcher.LoadGeoIP,
|
||||||
PatchFunc: nil,
|
PatchFunc: nil,
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
|
a, ok1 := params[0].(string)
|
||||||
|
b, ok2 := params[1].(string)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return geoMatcher.MatchGeoIp(a, b), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
|
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
|
||||||
},
|
},
|
||||||
@@ -527,7 +532,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
InitFunc: geoMatcher.LoadGeoSite,
|
InitFunc: geoMatcher.LoadGeoSite,
|
||||||
PatchFunc: nil,
|
PatchFunc: nil,
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
|
a, ok1 := params[0].(string)
|
||||||
|
b, ok2 := params[1].(string)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return geoMatcher.MatchGeoSite(a, b), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
||||||
},
|
},
|
||||||
@@ -535,7 +545,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
InitFunc: geoMatcher.LoadGeoSite,
|
InitFunc: geoMatcher.LoadGeoSite,
|
||||||
PatchFunc: nil,
|
PatchFunc: nil,
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil
|
a, ok1 := params[0].(string)
|
||||||
|
b, ok2 := params[1].(*geo.SiteConditionSet)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return geoMatcher.MatchGeoSiteSet(a, b), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{
|
Types: []reflect.Type{
|
||||||
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
|
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
|
||||||
@@ -556,7 +571,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
|
a, ok1 := params[0].(string)
|
||||||
|
b, ok2 := params[1].(*net.IPNet)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return builtins.MatchCIDR(a, b), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
|
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
|
||||||
},
|
},
|
||||||
@@ -565,7 +585,6 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
PatchFunc: func(args *[]ast.Node) error {
|
PatchFunc: func(args *[]ast.Node) error {
|
||||||
var serverStr *ast.StringNode
|
var serverStr *ast.StringNode
|
||||||
if len(*args) > 1 {
|
if len(*args) > 1 {
|
||||||
// Has the optional server argument
|
|
||||||
var ok bool
|
var ok bool
|
||||||
serverStr, ok = (*args)[1].(*ast.StringNode)
|
serverStr, ok = (*args)[1].(*ast.StringNode)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -595,9 +614,14 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
|
|||||||
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
|
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
a, ok1 := params[0].(string)
|
||||||
|
b, ok2 := params[1].(*net.Resolver)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
out, err := b.LookupHost(ctx, a)
|
||||||
if err != nil && stats != nil {
|
if err != nil && stats != nil {
|
||||||
stats.LookupErrors.Add(1)
|
stats.LookupErrors.Add(1)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user