Compare commits

...

8 Commits

24 changed files with 1654 additions and 320 deletions
+2 -2
View File
@@ -130,8 +130,8 @@ func (s *httpStream) parseResponseLine() utils.LSMAction {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
version := fields[0] version := fields[0]
status, _ := strconv.Atoi(fields[1]) status, err := strconv.Atoi(fields[1])
if !strings.HasPrefix(version, "HTTP/") || status == 0 { if err != nil || !strings.HasPrefix(version, "HTTP/") || status == 0 {
// Invalid version // Invalid version
return utils.LSMActionCancel return utils.LSMActionCancel
} }
+14 -4
View File
@@ -6,6 +6,8 @@ import (
"git.difuse.io/Difuse/Mellaris/analyzer/utils" "git.difuse.io/Difuse/Mellaris/analyzer/utils"
) )
const maxHandshakeLen = 65536
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil) var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)
type TLSAnalyzer struct{} type TLSAnalyzer struct{}
@@ -30,12 +32,14 @@ type tlsStream struct {
reqUpdated bool reqUpdated bool
reqLSM *utils.LinearStateMachine reqLSM *utils.LinearStateMachine
reqDone bool reqDone bool
reqFed int
respBuf *utils.ByteBuffer respBuf *utils.ByteBuffer
respMap analyzer.PropMap respMap analyzer.PropMap
respUpdated bool respUpdated bool
respLSM *utils.LinearStateMachine respLSM *utils.LinearStateMachine
respDone bool respDone bool
respFed int
clientHelloLen int clientHelloLen int
serverHelloLen int serverHelloLen int
@@ -64,7 +68,10 @@ func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyz
var update *analyzer.PropUpdate var update *analyzer.PropUpdate
var cancelled bool var cancelled bool
if rev { if rev {
s.respBuf.Append(data) if len(data) > s.respFed {
s.respBuf.Append(data[s.respFed:])
s.respFed = len(data)
}
s.respUpdated = false s.respUpdated = false
cancelled, s.respDone = s.respLSM.Run() cancelled, s.respDone = s.respLSM.Run()
if s.respUpdated { if s.respUpdated {
@@ -75,7 +82,10 @@ func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyz
s.respUpdated = false s.respUpdated = false
} }
} else { } else {
s.reqBuf.Append(data) if len(data) > s.reqFed {
s.reqBuf.Append(data[s.reqFed:])
s.reqFed = len(data)
}
s.reqUpdated = false s.reqUpdated = false
cancelled, s.reqDone = s.reqLSM.Run() cancelled, s.reqDone = s.reqLSM.Run()
if s.reqUpdated { if s.reqUpdated {
@@ -115,7 +125,7 @@ func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
} }
s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.clientHelloLen < minDataSize { if s.clientHelloLen < minDataSize || s.clientHelloLen > maxHandshakeLen {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
@@ -159,7 +169,7 @@ func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
} }
s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
if s.serverHelloLen < minDataSize { if s.serverHelloLen < minDataSize || s.serverHelloLen > maxHandshakeLen {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
+3 -2
View File
@@ -38,6 +38,7 @@ const (
OpenVPNMinPktLen = 6 OpenVPNMinPktLen = 6
OpenVPNTCPPktDefaultLimit = 256 OpenVPNTCPPktDefaultLimit = 256
OpenVPNUDPPktDefaultLimit = 256 OpenVPNUDPPktDefaultLimit = 256
OpenVPNTCPMaxPktLen = 4096
) )
type OpenVPNAnalyzer struct{} type OpenVPNAnalyzer struct{}
@@ -195,7 +196,7 @@ func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream {
} }
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) { func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
if len(data) == 0 { if len(data) < OpenVPNMinPktLen {
return nil, false return nil, false
} }
var update *analyzer.PropUpdate var update *analyzer.PropUpdate
@@ -338,7 +339,7 @@ func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAc
return nil, utils.LSMActionPause return nil, utils.LSMActionPause
} }
if pktLen < OpenVPNMinPktLen { if pktLen < OpenVPNMinPktLen || pktLen > OpenVPNTCPMaxPktLen {
return nil, utils.LSMActionCancel return nil, utils.LSMActionCancel
} }
+4
View File
@@ -14,6 +14,7 @@ import (
const ( const (
quicInvalidCountThreshold = 16 quicInvalidCountThreshold = 16
quicMaxCryptoDataLen = 256 * 1024 quicMaxCryptoDataLen = 256 * 1024
quicMaxFrameEntries = 100
) )
var ( var (
@@ -158,6 +159,9 @@ func (s *quicStream) mergeFrame(offset int64, data []byte) {
if len(data) == 0 || offset < 0 { if len(data) == 0 || offset < 0 {
return return
} }
if len(s.frames) >= quicMaxFrameEntries {
return
}
if s.frames == nil { if s.frames == nil {
s.frames = make(map[int64][]byte) s.frames = make(map[int64][]byte)
} }
+44 -10
View File
@@ -5,6 +5,7 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -13,13 +14,20 @@ import (
var _ Engine = (*engine)(nil) var _ Engine = (*engine)(nil)
type verdictEntry struct { type verdictEntry struct {
Verdict io.Verdict Verdict io.Verdict
Gen int64 Gen int64
CreatedAt time.Time
} }
const (
verdictTTL = 15 * time.Second
verdictSweepInterval = 15 * time.Second
)
type engine struct { type engine struct {
logger Logger logger Logger
io io.PacketIO io io.PacketIO
macResolver *sourceMACResolver
workers []*worker workers []*worker
stats *statsCounters stats *statsCounters
verdicts sync.Map // streamID(uint32) -> verdictEntry verdicts sync.Map // streamID(uint32) -> verdictEntry
@@ -73,6 +81,7 @@ func NewEngine(config Config) (Engine, error) {
e := &engine{ e := &engine{
logger: config.Logger, logger: config.Logger,
io: config.IO, io: config.IO,
macResolver: macResolver,
workers: workers, workers: workers,
stats: stats, stats: stats,
overflowPolicy: overflowPolicy, overflowPolicy: overflowPolicy,
@@ -83,7 +92,6 @@ func NewEngine(config Config) (Engine, error) {
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
e.verdictsGen.Add(1) e.verdictsGen.Add(1)
e.verdicts = sync.Map{}
for _, w := range e.workers { for _, w := range e.workers {
if err := w.UpdateRuleset(r); err != nil { if err := w.UpdateRuleset(r); err != nil {
return err return err
@@ -99,7 +107,11 @@ func (e *engine) Run(ctx context.Context) error {
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
if e.macResolver != nil {
go e.macResolver.Run(ioCtx)
}
go e.drainResults(ioCtx) go e.drainResults(ioCtx)
go e.sweepVerdicts(ioCtx)
errChan := make(chan error, 1) errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
@@ -124,11 +136,13 @@ func (e *engine) Run(ctx context.Context) error {
func (e *engine) dispatch(p io.Packet) bool { func (e *engine) dispatch(p io.Packet) bool {
streamID := p.StreamID() streamID := p.StreamID()
if v, ok := e.verdicts.Load(streamID); ok { if streamID != 0 {
entry := v.(verdictEntry) if v, ok := e.verdicts.Load(streamID); ok {
if entry.Gen == e.verdictsGen.Load() { entry := v.(verdictEntry)
_ = e.io.SetVerdict(p, entry.Verdict, nil) if entry.Gen == e.verdictsGen.Load() {
return true _ = e.io.SetVerdict(p, entry.Verdict, nil)
return true
}
} }
} }
@@ -163,12 +177,32 @@ func (e *engine) dispatch(p io.Packet) bool {
} }
func (e *engine) applyWorkerResult(r workerResult) { func (e *engine) applyWorkerResult(r workerResult) {
if r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream { if r.StreamID != 0 && (r.Verdict == io.VerdictAcceptStream || r.Verdict == io.VerdictDropStream) {
e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen}) e.verdicts.Store(r.StreamID, verdictEntry{Verdict: r.Verdict, Gen: r.Gen, CreatedAt: time.Now()})
} }
_ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket) _ = e.io.SetVerdict(r.Packet, r.Verdict, r.ModifiedPacket)
} }
func (e *engine) sweepVerdicts(ctx context.Context) {
ticker := time.NewTicker(verdictSweepInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
e.verdicts.Range(func(key, value interface{}) bool {
entry := value.(verdictEntry)
if now.Sub(entry.CreatedAt) > verdictTTL {
e.verdicts.Delete(key)
}
return true
})
}
}
}
func validPacket(data []byte) bool { func validPacket(data []byte) bool {
if len(data) == 0 { if len(data) == 0 {
return false return false
+30 -41
View File
@@ -5,6 +5,7 @@ package engine
import ( import (
"bufio" "bufio"
"context"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@@ -52,38 +53,6 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
return nil return nil
} }
now := time.Now()
r.mu.RLock()
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
out := append(net.HardwareAddr(nil), mac...)
r.mu.RUnlock()
return out
}
if mac := r.arpByIP[ipKey]; len(mac) != 0 && !arpRefreshDue {
out := append(net.HardwareAddr(nil), mac...)
r.mu.RUnlock()
return out
}
if mac := r.ndpByIP[ipKey]; len(mac) != 0 && !ndpRefreshDue {
out := append(net.HardwareAddr(nil), mac...)
r.mu.RUnlock()
return out
}
r.mu.RUnlock()
if ifaceRefreshDue {
r.refreshIfaceCache(now)
}
if arpRefreshDue {
r.refreshARPCache(now)
}
if ndpRefreshDue {
r.refreshNDPCache(now)
}
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
if mac := r.ifaceByIP[ipKey]; len(mac) != 0 { if mac := r.ifaceByIP[ipKey]; len(mac) != 0 {
@@ -95,18 +64,38 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
if mac := r.ndpByIP[ipKey]; len(mac) != 0 { if mac := r.ndpByIP[ipKey]; len(mac) != 0 {
return append(net.HardwareAddr(nil), mac...) return append(net.HardwareAddr(nil), mac...)
} }
return nil
}
// On-demand IPv6 neighbor lookup via route-netlink as a last fast path. func (r *sourceMACResolver) Run(ctx context.Context) {
if ip.To4() == nil { r.refreshAll(time.Now())
if mac, ok := lookupNeighborMACNetlink(ip); ok { ticker := time.NewTicker(arpCacheTTL)
out := append(net.HardwareAddr(nil), mac...) defer ticker.Stop()
r.mu.Lock() for {
r.ndpByIP[ipKey] = append(net.HardwareAddr(nil), mac...) select {
r.mu.Unlock() case <-ctx.Done():
return out return
case now := <-ticker.C:
r.refreshAll(now)
} }
} }
return nil }
func (r *sourceMACResolver) refreshAll(now time.Time) {
r.mu.RLock()
ifaceRefreshDue := now.Sub(r.lastIfaceRefresh) > ifaceCacheTTL
arpRefreshDue := now.Sub(r.lastARPRefresh) > arpCacheTTL
ndpRefreshDue := now.Sub(r.lastNDPRefresh) > ndpCacheTTL
r.mu.RUnlock()
if ifaceRefreshDue {
r.refreshIfaceCache(now)
}
if arpRefreshDue {
r.refreshARPCache(now)
}
if ndpRefreshDue {
r.refreshNDPCache(now)
}
} }
func (r *sourceMACResolver) refreshIfaceCache(now time.Time) { func (r *sourceMACResolver) refreshIfaceCache(now time.Time) {
+8 -1
View File
@@ -3,7 +3,10 @@
package engine package engine
import "net" import (
"context"
"net"
)
type sourceMACResolver struct{} type sourceMACResolver struct{}
@@ -15,3 +18,7 @@ func (r *sourceMACResolver) Resolve(ip net.IP) net.HardwareAddr {
_ = ip _ = ip
return nil return nil
} }
func (r *sourceMACResolver) Run(ctx context.Context) {
<-ctx.Done()
}
+70
View File
@@ -0,0 +1,70 @@
package engine
import (
"context"
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/io"
)
type recordingPacket struct {
streamID uint32
data []byte
}
func (p recordingPacket) StreamID() uint32 { return p.streamID }
func (p recordingPacket) Data() []byte { return p.data }
type recordingPacketIO struct {
verdict io.Verdict
}
func (r *recordingPacketIO) Register(context.Context, io.PacketCallback) error { return nil }
func (r *recordingPacketIO) SetVerdict(_ io.Packet, v io.Verdict, _ []byte) error {
r.verdict = v
return nil
}
func (r *recordingPacketIO) ProtectedDialContext(context.Context, string, string) (net.Conn, error) {
return nil, nil
}
func (r *recordingPacketIO) Close() error { return nil }
func TestEngineDefaultOverflowPolicyAccepts(t *testing.T) {
packetIO := &recordingPacketIO{}
eng, err := NewEngine(Config{
Logger: noopTestLogger{},
IO: packetIO,
Workers: 1,
WorkerQueueSize: 1,
})
if err != nil {
t.Fatalf("NewEngine error: %v", err)
}
e := eng.(*engine)
if e.overflowPolicy != OverflowPolicyAccept {
t.Fatalf("overflow policy=%v want=%v", e.overflowPolicy, OverflowPolicyAccept)
}
e.workers[0].packetChan <- &workerPacket{}
packet := recordingPacket{
streamID: 1,
data: serializeIPv6TCP(
t,
net.ParseIP("2001:db8::11").To16(),
net.ParseIP("2001:db8::22").To16(),
42310,
443,
1000,
),
}
e.dispatch(packet)
stats := e.Stats()
if packetIO.verdict != io.VerdictAccept {
t.Fatalf("overflow verdict=%v want=%v", packetIO.verdict, io.VerdictAccept)
}
if stats.OverflowEvents != 1 || stats.OverflowAccepts != 1 || stats.OverflowDrops != 0 {
t.Fatalf("overflow stats=%+v", stats)
}
}
+111 -19
View File
@@ -8,11 +8,24 @@ type L3Info struct {
IHL uint8 IHL uint8
SrcIP [4]byte SrcIP [4]byte
DstIP [4]byte DstIP [4]byte
SrcIPv6 [16]byte
DstIPv6 [16]byte
Length uint16 Length uint16
} }
func (i L3Info) SrcIPAddr() net.IP { return net.IP(i.SrcIP[:]) } func (i L3Info) SrcIPAddr() net.IP {
func (i L3Info) DstIPAddr() net.IP { return net.IP(i.DstIP[:]) } if i.Version == 6 {
return net.IP(i.SrcIPv6[:])
}
return net.IP(i.SrcIP[:])
}
func (i L3Info) DstIPAddr() net.IP {
if i.Version == 6 {
return net.IP(i.DstIPv6[:])
}
return net.IP(i.DstIP[:])
}
type TCPInfo struct { type TCPInfo struct {
SrcPort uint16 SrcPort uint16
@@ -32,29 +45,108 @@ type UDPInfo struct {
} }
func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) { func ParseL3(data []byte) (l3 L3Info, transport []byte, ok bool) {
if len(data) < 20 { if len(data) < 1 {
return return
} }
version := data[0] >> 4 version := data[0] >> 4
if version != 4 { switch version {
case 4:
if len(data) < 20 {
return
}
ihl := data[0] & 0x0F
if ihl < 5 || len(data) < int(ihl)*4 {
return
}
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 {
return
}
if totalLen > len(data) {
totalLen = len(data)
}
return L3Info{
Version: 4,
Protocol: data[9],
IHL: ihl,
Length: uint16(totalLen),
SrcIP: [4]byte{data[12], data[13], data[14], data[15]},
DstIP: [4]byte{data[16], data[17], data[18], data[19]},
}, data[ihl*4 : totalLen], true
case 6:
if len(data) < 40 {
return
}
payloadLen := int(uint16(data[4])<<8 | uint16(data[5]))
totalLen := 40 + payloadLen
if payloadLen == 0 || totalLen > len(data) {
totalLen = len(data)
}
protocol, tr, ipv6OK := parseIPv6Transport(data, data[6], totalLen)
if !ipv6OK {
return
}
var srcIP, dstIP [16]byte
copy(srcIP[:], data[8:24])
copy(dstIP[:], data[24:40])
return L3Info{
Version: 6,
Protocol: protocol,
SrcIPv6: srcIP,
DstIPv6: dstIP,
Length: uint16(totalLen),
}, tr, true
default:
return return
} }
ihl := data[0] & 0x0F }
if ihl < 5 || len(data) < int(ihl)*4 {
return func parseIPv6Transport(data []byte, nextHeader uint8, totalLen int) (protocol uint8, transport []byte, ok bool) {
offset := 40
proto := nextHeader
for {
if offset > totalLen {
return 0, nil, false
}
switch proto {
case 0, 43, 60: // Hop-by-hop options, Routing, Destination options
if offset+2 > totalLen {
return 0, nil, false
}
hdrLen := (int(data[offset+1]) + 1) * 8
if hdrLen < 8 || offset+hdrLen > totalLen {
return 0, nil, false
}
proto = data[offset]
offset += hdrLen
case 44: // Fragment
if offset+8 > totalLen {
return 0, nil, false
}
// Only first fragment carries L4 headers.
fragOffset := (uint16(data[offset+2])<<8 | uint16(data[offset+3])) >> 3
if fragOffset != 0 {
return 0, nil, false
}
proto = data[offset]
offset += 8
case 51: // Authentication Header
if offset+2 > totalLen {
return 0, nil, false
}
hdrLen := (int(data[offset+1]) + 2) * 4
if hdrLen < 8 || offset+hdrLen > totalLen {
return 0, nil, false
}
proto = data[offset]
offset += hdrLen
default:
if offset > totalLen {
return 0, nil, false
}
return proto, data[offset:totalLen], true
}
} }
totalLen := int(uint16(data[2])<<8 | uint16(data[3]))
if totalLen < int(ihl)*4 || totalLen > len(data) {
totalLen = len(data)
}
return L3Info{
Version: 4,
Protocol: data[9],
IHL: ihl,
Length: uint16(totalLen),
SrcIP: [4]byte{data[12], data[13], data[14], data[15]},
DstIP: [4]byte{data[16], data[17], data[18], data[19]},
}, data[ihl*4:totalLen], true
} }
func ParseTCP(transport []byte) (TCPInfo, []byte, bool) { func ParseTCP(transport []byte) (TCPInfo, []byte, bool) {
+127
View File
@@ -0,0 +1,127 @@
package engine
import (
"encoding/binary"
"net"
"testing"
"github.com/google/gopacket/layers"
)
func TestParseL3IPv6UDP(t *testing.T) {
src := net.ParseIP("2001:db8::10").To16()
dst := net.ParseIP("2001:db8::20").To16()
payload := []byte("hello")
pkt := buildIPv6UDPPacket(t, src, dst, 12345, 443, payload)
l3, transport, ok := ParseL3(pkt)
if !ok {
t.Fatal("ParseL3 should parse IPv6 packet")
}
if l3.Version != 6 {
t.Fatalf("version=%d want=6", l3.Version)
}
if l3.Protocol != 17 {
t.Fatalf("protocol=%d want=17 (udp)", l3.Protocol)
}
if !l3.SrcIPAddr().Equal(src) {
t.Fatalf("src=%v want=%v", l3.SrcIPAddr(), src)
}
if !l3.DstIPAddr().Equal(dst) {
t.Fatalf("dst=%v want=%v", l3.DstIPAddr(), dst)
}
udp, gotPayload, ok := ParseUDP(transport)
if !ok {
t.Fatal("ParseUDP should parse transport payload")
}
if udp.SrcPort != 12345 || udp.DstPort != 443 {
t.Fatalf("ports=%d->%d want=12345->443", udp.SrcPort, udp.DstPort)
}
if string(gotPayload) != string(payload) {
t.Fatalf("payload=%q want=%q", string(gotPayload), string(payload))
}
}
func TestParseL3IPv6HopByHopThenUDP(t *testing.T) {
src := net.ParseIP("2001:db8::1").To16()
dst := net.ParseIP("2001:db8::2").To16()
udpPayload := []byte("abc")
udpLen := 8 + len(udpPayload)
totalPayloadLen := 8 + udpLen // 8-byte hop-by-hop extension + udp packet
pkt := make([]byte, 40+totalPayloadLen)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(totalPayloadLen))
pkt[6] = 0 // Hop-by-hop
pkt[7] = 64 // Hop limit
copy(pkt[8:24], src)
copy(pkt[24:40], dst)
off := 40
pkt[off+0] = 17 // Next header: UDP
pkt[off+1] = 0 // Hdr Ext Len: (0+1)*8 = 8 bytes
off += 8
binary.BigEndian.PutUint16(pkt[off:off+2], 5353)
binary.BigEndian.PutUint16(pkt[off+2:off+4], 53)
binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen))
copy(pkt[off+8:], udpPayload)
l3, transport, ok := ParseL3(pkt)
if !ok {
t.Fatal("ParseL3 should parse IPv6 packet with extension headers")
}
if l3.Version != 6 || l3.Protocol != 17 {
t.Fatalf("version/protocol=%d/%d want=6/17", l3.Version, l3.Protocol)
}
udp, gotPayload, ok := ParseUDP(transport)
if !ok {
t.Fatal("ParseUDP should parse UDP after hop-by-hop extension")
}
if udp.SrcPort != 5353 || udp.DstPort != 53 {
t.Fatalf("ports=%d->%d want=5353->53", udp.SrcPort, udp.DstPort)
}
if string(gotPayload) != string(udpPayload) {
t.Fatalf("payload=%q want=%q", string(gotPayload), string(udpPayload))
}
}
func buildIPv6UDPPacket(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte {
t.Helper()
udpLen := 8 + len(payload)
pkt := make([]byte, 40+udpLen)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen))
pkt[6] = 17
pkt[7] = 64
copy(pkt[8:24], src.To16())
copy(pkt[24:40], dst.To16())
off := 40
binary.BigEndian.PutUint16(pkt[off:off+2], srcPort)
binary.BigEndian.PutUint16(pkt[off+2:off+4], dstPort)
binary.BigEndian.PutUint16(pkt[off+4:off+6], uint16(udpLen))
copy(pkt[off+8:], payload)
return pkt
}
func TestParseL3IPv6FragmentNonFirst(t *testing.T) {
src := net.ParseIP("2001:db8::a").To16()
dst := net.ParseIP("2001:db8::b").To16()
// IPv6 header + fragment header (offset != 0)
pkt := make([]byte, 48)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], 8)
pkt[6] = 44
pkt[7] = 64
copy(pkt[8:24], src)
copy(pkt[24:40], dst)
pkt[40] = uint8(layers.IPProtocolUDP)
// fragment offset in 8-byte units: 1 (non-first fragment)
binary.BigEndian.PutUint16(pkt[42:44], 1<<3)
_, _, ok := ParseL3(pkt)
if ok {
t.Fatal("ParseL3 should reject non-first IPv6 fragments for L4 parsing")
}
}
+183 -28
View File
@@ -1,7 +1,6 @@
package engine package engine
import ( import (
"net"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type fixedRuleset struct { type fixedRuleset struct {
@@ -25,6 +22,89 @@ func (r fixedRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
return ruleset.MatchResult{Action: r.action} return ruleset.MatchResult{Action: r.action}
} }
type analyzerRuleset struct {
action ruleset.Action
ans []analyzer.Analyzer
}
func (r analyzerRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
return r.ans
}
func (r analyzerRuleset) Match(ruleset.StreamInfo) ruleset.MatchResult {
return ruleset.MatchResult{Action: r.action}
}
type countingTCPAnalyzer struct {
newCalls *int
feedCalls *int
}
func (a countingTCPAnalyzer) Name() string { return "tls" }
func (a countingTCPAnalyzer) Limit() int { return 0 }
func (a countingTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
(*a.newCalls)++
return countingTCPStream{feedCalls: a.feedCalls}
}
type countingTCPStream struct {
feedCalls *int
}
func (s countingTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
(*s.feedCalls)++
return nil, false
}
func (s countingTCPStream) Close(bool) *analyzer.PropUpdate {
return nil
}
type logFinalizingRuleset struct {
ans []analyzer.Analyzer
}
func (r logFinalizingRuleset) Analyzers(ruleset.StreamInfo) []analyzer.Analyzer {
return r.ans
}
func (r logFinalizingRuleset) Match(info ruleset.StreamInfo) ruleset.MatchResult {
if _, ok := info.Props["tls"]; ok {
return ruleset.MatchResult{Action: ruleset.ActionMaybe, Logged: true}
}
return ruleset.MatchResult{Action: ruleset.ActionMaybe}
}
func (r logFinalizingRuleset) CanFinalizeAfterLog(ruleset.StreamInfo, []string) bool {
return true
}
type requestPropTCPAnalyzer struct {
closeCalls *int
}
func (a requestPropTCPAnalyzer) Name() string { return "tls" }
func (a requestPropTCPAnalyzer) Limit() int { return 0 }
func (a requestPropTCPAnalyzer) NewTCP(analyzer.TCPInfo, analyzer.Logger) analyzer.TCPStream {
return requestPropTCPStream{closeCalls: a.closeCalls}
}
type requestPropTCPStream struct {
closeCalls *int
}
func (s requestPropTCPStream) Feed(bool, bool, bool, int, []byte) (*analyzer.PropUpdate, bool) {
return &analyzer.PropUpdate{
Type: analyzer.PropUpdateMerge,
M: analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
}, false
}
func (s requestPropTCPStream) Close(bool) *analyzer.PropUpdate {
(*s.closeCalls)++
return nil
}
type noopTestLogger struct{} type noopTestLogger struct{}
func (noopTestLogger) WorkerStart(int) {} func (noopTestLogger) WorkerStart(int) {}
@@ -60,25 +140,19 @@ func TestUDPStreamUsesUpdatedRuleset(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow}, Ruleset: fixedRuleset{action: ruleset.ActionAllow},
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
udp := &layers.UDP{ payload := []byte("query")
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx := &udpContext{Verdict: udpVerdictAccept} ctx := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx) s := f.New(tuple, payload, ctx)
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err) t.Fatalf("update ruleset: %v", err)
} }
if !s.Accept(udp, false, ctx) { if !s.Accept(false, ctx) {
t.Fatalf("unexpected Accept=false for virgin stream") t.Fatalf("unexpected Accept=false for virgin stream")
} }
s.Feed(udp, false, ctx) s.Feed(false, payload, ctx)
if ctx.Verdict != udpVerdictDropStream { if ctx.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream) t.Fatalf("verdict=%v want=%v", ctx.Verdict, udpVerdictDropStream)
} }
@@ -96,21 +170,15 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
Ruleset: fixedRuleset{action: ruleset.ActionAllow}, Ruleset: fixedRuleset{action: ruleset.ActionAllow},
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 12345, BPort: 53}
udp := &layers.UDP{ payload := []byte("query")
SrcPort: 12345,
DstPort: 53,
BaseLayer: layers.BaseLayer{
Payload: []byte("query"),
},
}
ctx1 := &udpContext{Verdict: udpVerdictAccept} ctx1 := &udpContext{Verdict: udpVerdictAccept}
s := f.New(ipFlow, udp.TransportFlow(), udp, ctx1) s := f.New(tuple, payload, ctx1)
if !s.Accept(udp, false, ctx1) { if !s.Accept(false, ctx1) {
t.Fatalf("unexpected Accept=false before first feed") t.Fatalf("unexpected Accept=false before first feed")
} }
s.Feed(udp, false, ctx1) s.Feed(false, payload, ctx1)
if ctx1.Verdict != udpVerdictAcceptStream { if ctx1.Verdict != udpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream) t.Fatalf("verdict=%v want=%v", ctx1.Verdict, udpVerdictAcceptStream)
} }
@@ -120,16 +188,16 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
} }
ctx2 := &udpContext{Verdict: udpVerdictAccept} ctx2 := &udpContext{Verdict: udpVerdictAccept}
if !s.Accept(udp, false, ctx2) { if !s.Accept(false, ctx2) {
t.Fatalf("expected Accept=true after ruleset update") t.Fatalf("expected Accept=true after ruleset update")
} }
s.Feed(udp, false, ctx2) s.Feed(false, payload, ctx2)
if ctx2.Verdict != udpVerdictDropStream { if ctx2.Verdict != udpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream) t.Fatalf("verdict=%v want=%v", ctx2.Verdict, udpVerdictDropStream)
} }
ctx3 := &udpContext{Verdict: udpVerdictAccept} ctx3 := &udpContext{Verdict: udpVerdictAccept}
if s.Accept(udp, false, ctx3) { if s.Accept(false, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
} }
if ctx3.Verdict != udpVerdictDropStream { if ctx3.Verdict != udpVerdictDropStream {
@@ -222,3 +290,90 @@ func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream) t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
} }
} }
func TestTCPFlowDelaysAnalyzerCreationUntilPayload(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
newCalls := 0
feedCalls := 0
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
mgr.updateRuleset(analyzerRuleset{
action: ruleset.ActionMaybe,
ans: []analyzer.Analyzer{countingTCPAnalyzer{
newCalls: &newCalls,
feedCalls: &feedCalls,
}},
}, 0)
l3 := L3Info{
Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v := mgr.handle(1, l3, tcp, nil, nil, nil)
if v != io.VerdictAccept {
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
}
if newCalls != 0 || feedCalls != 0 {
t.Fatalf("empty packet created/feed analyzer: new=%d feed=%d", newCalls, feedCalls)
}
tcp.Seq = 101
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
if v != io.VerdictAccept {
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAccept)
}
if newCalls != 1 || feedCalls != 1 {
t.Fatalf("payload should create/feed analyzer once: new=%d feed=%d", newCalls, feedCalls)
}
}
func TestTCPFlowFinalizesAfterLogClassification(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
closeCalls := 0
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node, newAnalyzerSelector(AnalyzerSelectionModeSignature, &statsCounters{}))
mgr.updateRuleset(logFinalizingRuleset{
ans: []analyzer.Analyzer{requestPropTCPAnalyzer{closeCalls: &closeCalls}},
}, 0)
l3 := L3Info{
Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v := mgr.handle(1, l3, tcp, nil, nil, nil)
if v != io.VerdictAccept {
t.Fatalf("empty packet verdict=%v want=%v", v, io.VerdictAccept)
}
tcp.Seq = 101
v = mgr.handle(1, l3, tcp, []byte{0x16, 0x03, 0x01}, nil, nil)
if v != io.VerdictAcceptStream {
t.Fatalf("payload verdict=%v want=%v", v, io.VerdictAcceptStream)
}
if closeCalls != 1 {
t.Fatalf("expected analyzer to be closed once after finalization, got %d", closeCalls)
}
if _, ok := mgr.flows[1]; ok {
t.Fatal("expected finalized TCP flow to be removed from manager")
}
}
+101 -13
View File
@@ -3,6 +3,7 @@ package engine
import ( import (
"net" "net"
"sync" "sync"
"time"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
@@ -13,6 +14,8 @@ import (
const tcpFlowMaxBuffer = 16384 const tcpFlowMaxBuffer = 16384
const tcpFlowIdleTimeout = 10 * time.Minute
type tcpFlowDirection uint8 type tcpFlowDirection uint8
const ( const (
@@ -22,10 +25,10 @@ const (
type tcpFlow struct { type tcpFlow struct {
streamID uint32 streamID uint32
srcIP [4]byte
dstIP [4]byte
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
srcIP net.IP
dstIP net.IP
dirSeq [2]uint32 dirSeq [2]uint32
dirBuf [2][]byte dirBuf [2][]byte
@@ -39,6 +42,10 @@ type tcpFlow struct {
doneEntries []*tcpFlowEntry doneEntries []*tcpFlowEntry
lastVerdict io.Verdict lastVerdict io.Verdict
feedCalled [2]bool feedCalled [2]bool
lastSeen time.Time
pendingAnalyzers []analyzer.Analyzer
selector *analyzerSelector
} }
type tcpFlowEntry struct { type tcpFlowEntry struct {
@@ -52,7 +59,7 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
rs, version := f.currentRuleset() rs, version := f.currentRuleset()
rulesetChanged := version != f.rulesetVersion rulesetChanged := version != f.rulesetVersion
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 { if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() {
return f.lastVerdict return f.lastVerdict
} }
@@ -66,19 +73,23 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
propUpdated := false propUpdated := false
if len(payload) > 0 { if len(payload) > 0 {
dir, rev := f.resolveDirection(tcp) dir, rev := f.resolveDirection(tcp)
if len(f.pendingAnalyzers) > 0 {
f.initPendingAnalyzers(payload)
}
expected := f.dirSeq[dir] expected := f.dirSeq[dir]
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected { if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
f.feedCalled[dir] = true f.feedCalled[dir] = true
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer { if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
propUpdated = f.feedAnalyzers(rev) propUpdated = f.feedAnalyzers(rev)
} }
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
} }
} }
f.runMatch(rs, version, rulesetChanged, propUpdated) f.runMatch(rs, version, rulesetChanged, propUpdated)
f.maybeFinalizeVerdict() f.maybeFinalizeVerdict()
f.lastSeen = time.Now()
return f.lastVerdict return f.lastVerdict
} }
@@ -105,6 +116,52 @@ func (f *tcpFlow) feedAnalyzers(rev bool) bool {
return updated return updated
} }
func (f *tcpFlow) initPendingAnalyzers(payload []byte) {
baseAns := f.pendingAnalyzers
f.pendingAnalyzers = nil
if f.selector != nil {
baseAns = f.selector.SelectTCP(baseAns, payload)
}
ans := analyzersToTCPAnalyzers(baseAns)
if len(ans) == 0 {
return
}
entries := make([]*tcpFlowEntry, 0, len(ans))
for _, a := range ans {
entries = append(entries, &tcpFlowEntry{
Name: a.Name(),
Stream: a.NewTCP(analyzer.TCPInfo{
SrcIP: f.srcIP,
DstIP: f.dstIP,
SrcPort: f.srcPort,
DstPort: f.dstPort,
}, &analyzerLogger{
StreamID: f.info.ID,
Name: a.Name(),
Logger: f.logger,
}),
HasLimit: a.Limit() > 0,
Quota: a.Limit(),
})
}
f.activeEntries = append(f.activeEntries, entries...)
}
func (f *tcpFlow) hasPendingAnalyzers() bool {
return len(f.pendingAnalyzers) > 0
}
func (f *tcpFlow) analyzerNames() []string {
names := make([]string, 0, len(f.activeEntries)+len(f.pendingAnalyzers))
for _, entry := range f.activeEntries {
names = append(names, entry.Name)
}
for _, a := range f.pendingAnalyzers {
names = append(names, a.Name())
}
return names
}
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) { func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool, propUpdated bool) {
if !propUpdated && !f.virgin && !rulesetChanged { if !propUpdated && !f.virgin && !rulesetChanged {
return return
@@ -122,11 +179,15 @@ func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bo
f.lastVerdict = verdict f.lastVerdict = verdict
f.closeActiveEntries() f.closeActiveEntries()
f.logger.TCPStreamAction(f.info, action, false) f.logger.TCPStreamAction(f.info, action, false)
} else if result.Logged && canFinalizeAfterLog(rs, f.info, f.analyzerNames()) {
f.lastVerdict = io.VerdictAcceptStream
f.closeActiveEntries()
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
} }
} }
func (f *tcpFlow) maybeFinalizeVerdict() { func (f *tcpFlow) maybeFinalizeVerdict() {
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept { if len(f.activeEntries) == 0 && !f.hasPendingAnalyzers() && f.lastVerdict == io.VerdictAccept {
f.lastVerdict = io.VerdictAcceptStream f.lastVerdict = io.VerdictAcceptStream
f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true) f.logger.TCPStreamAction(f.info, ruleset.ActionAllow, true)
} }
@@ -203,8 +264,8 @@ func (m *tcpFlowManager) handle(streamID uint32, l3 L3Info, tcp TCPInfo, payload
func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow { func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) *tcpFlow {
id := m.sfNode.Generate() id := m.sfNode.Generate()
ipSrc := net.IP(l3.SrcIP[:]) ipSrc := l3.SrcIPAddr()
ipDst := net.IP(l3.DstIP[:]) ipDst := l3.DstIPAddr()
if len(srcMAC) == 0 && m.macResolver != nil { if len(srcMAC) == 0 && m.macResolver != nil {
srcMAC = m.macResolver.Resolve(ipSrc) srcMAC = m.macResolver.Resolve(ipSrc)
} }
@@ -220,12 +281,18 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
m.logger.TCPStreamNew(m.workerID, info) m.logger.TCPStreamNew(m.workerID, info)
rs, version := m.rulesetSource() var rs ruleset.Ruleset
var version uint64
if m.rulesetSource != nil {
rs, version = m.rulesetSource()
}
var ans []analyzer.TCPAnalyzer var ans []analyzer.TCPAnalyzer
if rs != nil { if rs != nil {
baseAns := rs.Analyzers(info) baseAns := rs.Analyzers(info)
baseAns = m.selector.SelectTCP(baseAns, payload) if len(payload) > 0 {
ans = analyzersToTCPAnalyzers(baseAns) baseAns = m.selector.SelectTCP(baseAns, payload)
ans = analyzersToTCPAnalyzers(baseAns)
}
} }
entries := make([]*tcpFlowEntry, 0, len(ans)) entries := make([]*tcpFlowEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
@@ -248,10 +315,10 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
flow := &tcpFlow{ flow := &tcpFlow{
streamID: streamID, streamID: streamID,
srcIP: l3.SrcIP,
dstIP: l3.DstIP,
srcPort: tcp.SrcPort, srcPort: tcp.SrcPort,
dstPort: tcp.DstPort, dstPort: tcp.DstPort,
srcIP: ipSrc,
dstIP: ipDst,
info: info, info: info,
virgin: true, virgin: true,
logger: m.logger, logger: m.logger,
@@ -259,6 +326,11 @@ func (m *tcpFlowManager) createFlow(streamID uint32, l3 L3Info, tcp TCPInfo, pay
rulesetVersion: version, rulesetVersion: version,
activeEntries: entries, activeEntries: entries,
lastVerdict: io.VerdictAccept, lastVerdict: io.VerdictAccept,
lastSeen: time.Now(),
selector: m.selector,
}
if len(payload) == 0 && rs != nil {
flow.pendingAnalyzers = rs.Analyzers(info)
} }
flow.dirSeq[tcpDirC2S] = tcp.Seq + 1 flow.dirSeq[tcpDirC2S] = tcp.Seq + 1
return flow return flow
@@ -270,6 +342,17 @@ func (m *tcpFlowManager) updateRuleset(r ruleset.Ruleset, version uint64) {
} }
} }
func (m *tcpFlowManager) cleanupIdle(now time.Time) {
m.mu.Lock()
defer m.mu.Unlock()
for id, flow := range m.flows {
if now.Sub(flow.lastSeen) > tcpFlowIdleTimeout {
flow.closeActiveEntries()
delete(m.flows, id)
}
}
}
func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { func feedFlowEntry(entry *tcpFlowEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
if !entry.HasLimit { if !entry.HasLimit {
update, done = entry.Stream.Feed(rev, true, false, 0, data) update, done = entry.Stream.Feed(rev, true, false, 0, data)
@@ -308,3 +391,8 @@ func actionToTCPVerdict(a ruleset.Action) io.Verdict {
return io.VerdictAcceptStream return io.VerdictAcceptStream
} }
} }
func canFinalizeAfterLog(rs ruleset.Ruleset, info ruleset.StreamInfo, activeAnalyzers []string) bool {
finalizer, ok := rs.(ruleset.LogFinalizer)
return ok && finalizer.CanFinalizeAfterLog(info, activeAnalyzers)
}
+116 -69
View File
@@ -2,6 +2,7 @@ package engine
import ( import (
"bytes" "bytes"
"container/list"
"errors" "errors"
"net" "net"
"sync" "sync"
@@ -12,9 +13,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2"
) )
// udpVerdict is a subset of io.Verdict for UDP streams. // udpVerdict is a subset of io.Verdict for UDP streams.
@@ -49,9 +47,10 @@ type udpStreamFactory struct {
RulesetVersion uint64 RulesetVersion uint64
} }
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { func (f *udpStreamFactory) New(k udpTupleKey, payload []byte, uc *udpContext) *udpStream {
id := f.Node.Generate() id := f.Node.Generate()
ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) ipSrc := net.IP(k.AIP[:k.ALen])
ipDst := net.IP(k.BIP[:k.BLen])
info := ruleset.StreamInfo{ info := ruleset.StreamInfo{
ID: id.Int64(), ID: id.Int64(),
Protocol: ruleset.ProtocolUDP, Protocol: ruleset.ProtocolUDP,
@@ -59,8 +58,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...), DstMAC: append(net.HardwareAddr(nil), uc.DstMAC...),
SrcIP: ipSrc, SrcIP: ipSrc,
DstIP: ipDst, DstIP: ipDst,
SrcPort: uint16(udp.SrcPort), SrcPort: k.APort,
DstPort: uint16(udp.DstPort), DstPort: k.BPort,
Props: make(analyzer.CombinedPropMap), Props: make(analyzer.CombinedPropMap),
} }
f.Logger.UDPStreamNew(f.WorkerID, info) f.Logger.UDPStreamNew(f.WorkerID, info)
@@ -69,11 +68,10 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
if rs != nil { if rs != nil {
baseAns := rs.Analyzers(info) baseAns := rs.Analyzers(info)
if f.Selector != nil { if f.Selector != nil {
baseAns = f.Selector.SelectUDP(baseAns, udp.Payload) baseAns = f.Selector.SelectUDP(baseAns, payload)
} }
ans = analyzersToUDPAnalyzers(baseAns) ans = analyzersToUDPAnalyzers(baseAns)
} }
// Create entries for each analyzer
entries := make([]*udpStreamEntry, 0, len(ans)) entries := make([]*udpStreamEntry, 0, len(ans))
for _, a := range ans { for _, a := range ans {
entries = append(entries, &udpStreamEntry{ entries = append(entries, &udpStreamEntry{
@@ -81,8 +79,8 @@ func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, u
Stream: a.NewUDP(analyzer.UDPInfo{ Stream: a.NewUDP(analyzer.UDPInfo{
SrcIP: ipSrc, SrcIP: ipSrc,
DstIP: ipDst, DstIP: ipDst,
SrcPort: uint16(udp.SrcPort), SrcPort: k.APort,
DstPort: uint16(udp.DstPort), DstPort: k.BPort,
}, &analyzerLogger{ }, &analyzerLogger{
StreamID: id.Int64(), StreamID: id.Int64(),
Name: a.Name(), Name: a.Name(),
@@ -118,16 +116,24 @@ func (f *udpStreamFactory) currentRuleset() (ruleset.Ruleset, uint64) {
type udpStreamManager struct { type udpStreamManager struct {
factory *udpStreamFactory factory *udpStreamFactory
streams *lru.Cache[uint32, *udpStreamValue] streams map[uint32]*list.Element
order *list.List
maxStreams int
tupleIndex map[udpTupleKey]uint32 tupleIndex map[udpTupleKey]uint32
streamTuples map[uint32]udpTupleKey streamTuples map[uint32]udpTupleKey
stats *statsCounters stats *statsCounters
} }
type udpStreamValue struct { type udpStreamValue struct {
Stream *udpStream StreamID uint32
IPFlow gopacket.Flow Stream *udpStream
UDPFlow gopacket.Flow Tuple udpTupleKey
}
func (v *udpStreamValue) Match(k udpTupleKey) (ok, rev bool) {
fwd := v.Tuple == k
rev = v.Tuple == reverseTuple(k)
return fwd || rev, rev
} }
type udpTupleKey struct { type udpTupleKey struct {
@@ -139,39 +145,28 @@ type udpTupleKey struct {
BPort uint16 BPort uint16
} }
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow
rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()
return fwd || rev, rev
}
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) { func newUDPStreamManager(factory *udpStreamFactory, maxStreams int, stats *statsCounters) (*udpStreamManager, error) {
if maxStreams <= 0 {
maxStreams = 1
}
m := &udpStreamManager{ m := &udpStreamManager{
factory: factory, factory: factory,
streams: make(map[uint32]*list.Element, maxStreams),
order: list.New(),
maxStreams: maxStreams,
tupleIndex: make(map[udpTupleKey]uint32, maxStreams), tupleIndex: make(map[udpTupleKey]uint32, maxStreams),
streamTuples: make(map[uint32]udpTupleKey, maxStreams), streamTuples: make(map[uint32]udpTupleKey, maxStreams),
stats: stats, stats: stats,
} }
ss, err := lru.NewWithEvict[uint32, *udpStreamValue](maxStreams, func(k uint32, v *udpStreamValue) {
m.removeTupleMappingLocked(k)
})
if err != nil {
return nil, err
}
m.streams = ss
return m, nil return m, nil
} }
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) { func (m *udpStreamManager) MatchWithContext(streamID uint32, tuple udpTupleKey, rev bool, payload []byte, uc *udpContext) {
rev := false value, ok := m.get(streamID)
value, ok := m.streams.Get(streamID)
tuple := canonicalUDPTupleKey(ipFlow, udp)
if !ok { if !ok {
if m.stats != nil { if m.stats != nil {
m.stats.UDPTupleLookups.Add(1) m.stats.UDPTupleLookups.Add(1)
} }
// Conntrack IDs can change during early flow lifetime on some systems.
// Rebind by canonical 5-tuple in O(1).
matchedKey, found := m.tupleIndex[tuple] matchedKey, found := m.tupleIndex[tuple]
var matchedValue *udpStreamValue var matchedValue *udpStreamValue
var matchedRev bool var matchedRev bool
@@ -180,7 +175,7 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
m.stats.UDPTupleHits.Add(1) m.stats.UDPTupleHits.Add(1)
} }
var hasValue bool var hasValue bool
matchedValue, hasValue = m.streams.Get(matchedKey) matchedValue, hasValue = m.get(matchedKey)
if !hasValue || matchedValue == nil { if !hasValue || matchedValue == nil {
delete(m.tupleIndex, tuple) delete(m.tupleIndex, tuple)
delete(m.streamTuples, matchedKey) delete(m.streamTuples, matchedKey)
@@ -188,41 +183,88 @@ func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flo
} }
} }
if found { if found {
_, matchedRev = matchedValue.Match(ipFlow, udp.TransportFlow()) _, matchedRev = matchedValue.Match(tuple)
value = matchedValue value = matchedValue
rev = matchedRev rev = matchedRev
if matchedKey != streamID { if matchedKey != streamID {
m.streams.Remove(matchedKey) m.remove(matchedKey, false)
m.streams.Add(streamID, matchedValue) matchedValue.StreamID = streamID
m.add(streamID, matchedValue)
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// New stream
value = &udpStreamValue{ value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), StreamID: streamID,
IPFlow: ipFlow, Stream: m.factory.New(tuple, payload, uc),
UDPFlow: udp.TransportFlow(), Tuple: tuple,
} }
m.streams.Add(streamID, value) m.add(streamID, value)
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} else { } else {
// Stream ID exists, but is it really the same stream? ok, rev = value.Match(tuple)
ok, rev = value.Match(ipFlow, udp.TransportFlow())
if !ok { if !ok {
// It's not - close the old stream & replace it with a new one
value.Stream.Close() value.Stream.Close()
value = &udpStreamValue{ value = &udpStreamValue{
Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), StreamID: streamID,
IPFlow: ipFlow, Stream: m.factory.New(tuple, payload, uc),
UDPFlow: udp.TransportFlow(), Tuple: tuple,
} }
m.streams.Add(streamID, value) m.add(streamID, value)
m.bindTupleLocked(streamID, tuple) m.bindTupleLocked(streamID, tuple)
} }
} }
if value.Stream.Accept(udp, rev, uc) { if value.Stream.Accept(rev, uc) {
value.Stream.Feed(udp, rev, uc) value.Stream.Feed(rev, payload, uc)
}
}
func (m *udpStreamManager) get(streamID uint32) (*udpStreamValue, bool) {
ele, ok := m.streams[streamID]
if !ok || ele == nil {
return nil, false
}
m.order.MoveToFront(ele)
value, ok := ele.Value.(*udpStreamValue)
return value, ok && value != nil
}
func (m *udpStreamManager) add(streamID uint32, value *udpStreamValue) {
if value == nil {
return
}
if existing, ok := m.streams[streamID]; ok {
existing.Value = value
m.order.MoveToFront(existing)
return
}
value.StreamID = streamID
m.streams[streamID] = m.order.PushFront(value)
for len(m.streams) > m.maxStreams {
back := m.order.Back()
if back == nil {
return
}
evicted, _ := back.Value.(*udpStreamValue)
if evicted == nil {
m.order.Remove(back)
continue
}
m.remove(evicted.StreamID, true)
}
}
func (m *udpStreamManager) remove(streamID uint32, closeStream bool) {
ele, ok := m.streams[streamID]
if !ok || ele == nil {
return
}
value, _ := ele.Value.(*udpStreamValue)
delete(m.streams, streamID)
m.order.Remove(ele)
m.removeTupleMappingLocked(streamID)
if closeStream && value != nil && value.Stream != nil {
value.Stream.Close()
} }
} }
@@ -242,25 +284,34 @@ func (m *udpStreamManager) removeTupleMappingLocked(streamID uint32) {
} }
} }
func canonicalUDPTupleKey(ipFlow gopacket.Flow, udp *layers.UDP) udpTupleKey { func canonicalUDPTupleKey(srcIP, dstIP net.IP, srcPort, dstPort uint16) udpTupleKey {
srcIP := ipFlow.Src().Raw() srcRaw := []byte(srcIP)
dstIP := ipFlow.Dst().Raw() dstRaw := []byte(dstIP)
srcPort := uint16(udp.SrcPort)
dstPort := uint16(udp.DstPort)
if compareIPEndpoint(srcIP, srcPort, dstIP, dstPort) > 0 { if compareIPEndpoint(srcRaw, srcPort, dstRaw, dstPort) > 0 {
srcIP, dstIP = dstIP, srcIP srcRaw, dstRaw = dstRaw, srcRaw
srcPort, dstPort = dstPort, srcPort srcPort, dstPort = dstPort, srcPort
} }
var key udpTupleKey var key udpTupleKey
key.ALen = uint8(copy(key.AIP[:], srcIP)) key.ALen = uint8(copy(key.AIP[:], srcRaw))
key.BLen = uint8(copy(key.BIP[:], dstIP)) key.BLen = uint8(copy(key.BIP[:], dstRaw))
key.APort = srcPort key.APort = srcPort
key.BPort = dstPort key.BPort = dstPort
return key return key
} }
func reverseTuple(k udpTupleKey) udpTupleKey {
var r udpTupleKey
r.ALen = k.BLen
r.BLen = k.ALen
r.AIP = k.BIP
r.BIP = k.AIP
r.APort = k.BPort
r.BPort = k.APort
return r
}
func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int { func compareIPEndpoint(aIP []byte, aPort uint16, bIP []byte, bPort uint16) int {
if len(aIP) != len(bIP) { if len(aIP) != len(bIP) {
if len(aIP) < len(bIP) { if len(aIP) < len(bIP) {
@@ -298,11 +349,8 @@ type udpStreamEntry struct {
Quota int Quota int
} }
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { func (s *udpStream) Accept(rev bool, uc *udpContext) bool {
if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() { if len(s.activeEntries) > 0 || s.virgin || s.rulesetChanged() {
// Make sure every stream matches against the ruleset at least once,
// even if there are no activeEntries, as the ruleset may have built-in
// properties that need to be matched.
return true return true
} else { } else {
uc.Verdict = s.lastVerdict uc.Verdict = s.lastVerdict
@@ -310,12 +358,11 @@ func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
} }
} }
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { func (s *udpStream) Feed(rev bool, payload []byte, uc *udpContext) {
updated := false updated := false
for i := len(s.activeEntries) - 1; i >= 0; i-- { for i := len(s.activeEntries) - 1; i >= 0; i-- {
// Important: reverse order so we can remove entries
entry := s.activeEntries[i] entry := s.activeEntries[i]
update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload) update, closeUpdate, done := s.feedEntry(entry, rev, payload)
up1 := processPropUpdate(s.info.Props, entry.Name, update) up1 := processPropUpdate(s.info.Props, entry.Name, update)
up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate) up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
updated = updated || up1 || up2 updated = updated || up1 || up2
@@ -345,7 +392,7 @@ func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
action = ruleset.ActionMaybe action = ruleset.ActionMaybe
} else { } else {
var err error var err error
uc.Packet, err = udpMI.Process(udp.Payload) uc.Packet, err = udpMI.Process(payload)
if err != nil { if err != nil {
// Modifier error, fallback to maybe // Modifier error, fallback to maybe
s.logger.ModifyError(s.info, err) s.logger.ModifyError(s.info, err)
+26 -30
View File
@@ -1,20 +1,16 @@
package engine package engine
import ( import (
"net"
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type legacyUDPStreamValue struct { type legacyUDPStreamValue struct {
IPFlow gopacket.Flow Tuple udpTupleKey
UDPFlow gopacket.Flow
} }
type emptyRuleset struct{} type emptyRuleset struct{}
@@ -36,17 +32,20 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
} }
const flowCount = 20000 const flowCount = 20000
flows := make([]gopacket.Flow, flowCount) tuples := make([]udpTupleKey, flowCount)
udps := make([]*layers.UDP, flowCount) payloads := make([][]byte, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
a := byte(i >> 8) a := byte(i >> 8)
c := byte(i) c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) var t udpTupleKey
udps[i] = &layers.UDP{ t.AIP = [16]byte{10, a, 0, c}
SrcPort: layers.UDPPort(1024 + i%20000), t.ALen = 4
DstPort: layers.UDPPort(20000 + (i*7)%20000), t.BIP = [16]byte{172, 16, a, c}
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, t.BLen = 4
} t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
payloads[i] = []byte{0x01, 0x00, 0x00, 0x00}
} }
ctx := &udpContext{Verdict: udpVerdictAccept} ctx := &udpContext{Verdict: udpVerdictAccept}
@@ -59,7 +58,7 @@ func benchmarkUDPManager(b *testing.B, churn bool) {
} }
ctx.Verdict = udpVerdictAccept ctx.Verdict = udpVerdictAccept
ctx.Packet = nil ctx.Packet = nil
mgr.MatchWithContext(streamID, flows[idx], udps[idx], ctx) mgr.MatchWithContext(streamID, tuples[idx], false, payloads[idx], ctx)
} }
} }
@@ -73,27 +72,25 @@ func BenchmarkUDPManagerMatchStreamIDChurn(b *testing.B) {
func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) { func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
const flowCount = 5000 const flowCount = 5000
flows := make([]gopacket.Flow, flowCount) tuples := make([]udpTupleKey, flowCount)
udps := make([]*layers.UDP, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
a := byte(i >> 8) a := byte(i >> 8)
c := byte(i) c := byte(i)
flows[i] = gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, a, 0, c).To4(), net.IPv4(172, 16, a, c).To4()) var t udpTupleKey
udps[i] = &layers.UDP{ t.AIP = [16]byte{10, a, 0, c}
SrcPort: layers.UDPPort(1024 + i%20000), t.ALen = 4
DstPort: layers.UDPPort(20000 + (i*7)%20000), t.BIP = [16]byte{172, 16, a, c}
BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}, t.BLen = 4
} t.APort = 1024 + uint16(i%20000)
t.BPort = 20000 + uint16((i*7)%20000)
tuples[i] = t
} }
streams := make(map[uint32]*legacyUDPStreamValue, flowCount) streams := make(map[uint32]*legacyUDPStreamValue, flowCount)
keys := make([]uint32, 0, flowCount) keys := make([]uint32, 0, flowCount)
for i := 0; i < flowCount; i++ { for i := 0; i < flowCount; i++ {
streamID := uint32(i + 1) streamID := uint32(i + 1)
streams[streamID] = &legacyUDPStreamValue{ streams[streamID] = &legacyUDPStreamValue{Tuple: tuples[i]}
IPFlow: flows[i],
UDPFlow: udps[i].TransportFlow(),
}
keys = append(keys, streamID) keys = append(keys, streamID)
} }
@@ -104,15 +101,14 @@ func BenchmarkLegacyUDPFallbackScanChurn(b *testing.B) {
if _, ok := streams[streamID]; ok { if _, ok := streams[streamID]; ok {
continue continue
} }
ipFlow := flows[idx] tuple := tuples[idx]
udpFlow := udps[idx].TransportFlow() revTuple := reverseTuple(tuple)
for _, k := range keys { for _, k := range keys {
v, ok := streams[k] v, ok := streams[k]
if !ok || v == nil { if !ok || v == nil {
continue continue
} }
if (v.IPFlow == ipFlow && v.UDPFlow == udpFlow) || if v.Tuple == tuple || v.Tuple == revTuple {
(v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse()) {
delete(streams, k) delete(streams, k)
streams[streamID] = v streams[streamID] = v
break break
+4 -7
View File
@@ -1,7 +1,6 @@
package engine package engine
import ( import (
"net"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -9,8 +8,6 @@ import (
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
) )
type countingRuleset struct { type countingRuleset struct {
@@ -54,17 +51,17 @@ func TestUDPStreamManagerRebindsByTupleInO1Path(t *testing.T) {
t.Fatalf("new manager: %v", err) t.Fatalf("new manager: %v", err)
} }
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) tuple := udpTupleKey{AIP: [16]byte{10, 0, 0, 1}, BIP: [16]byte{10, 0, 0, 2}, ALen: 4, BLen: 4, APort: 50000, BPort: 443}
udp := &layers.UDP{SrcPort: 50000, DstPort: 443, BaseLayer: layers.BaseLayer{Payload: []byte{0x01, 0x00, 0x00, 0x00}}} payload := []byte{0x01, 0x00, 0x00, 0x00}
ctx1 := &udpContext{Verdict: udpVerdictAccept} ctx1 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(100, ipFlow, udp, ctx1) mgr.MatchWithContext(100, tuple, false, payload, ctx1)
if got := newCalls.Load(); got != 1 { if got := newCalls.Load(); got != 1 {
t.Fatalf("new stream calls=%d want=1", got) t.Fatalf("new stream calls=%d want=1", got)
} }
ctx2 := &udpContext{Verdict: udpVerdictAccept} ctx2 := &udpContext{Verdict: udpVerdictAccept}
mgr.MatchWithContext(200, ipFlow, udp, ctx2) mgr.MatchWithContext(200, tuple, false, payload, ctx2)
if got := newCalls.Load(); got != 1 { if got := newCalls.Load(); got != 1 {
t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got) t.Fatalf("expected stream reuse by tuple, new stream calls=%d want=1", got)
} }
+86 -51
View File
@@ -3,6 +3,7 @@ package engine
import ( import (
"context" "context"
"net" "net"
"time"
"git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
@@ -119,10 +120,16 @@ func (w *worker) FeedBlocking(p *workerPacket) {
func (w *worker) Run(ctx context.Context) { func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id) w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id) defer w.logger.WorkerStop(w.id)
tcpSweepTicker := time.NewTicker(1 * time.Minute)
defer tcpSweepTicker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-tcpSweepTicker.C:
w.tcpFlowMgr.cleanupIdle(time.Now())
case wp := <-w.packetChan: case wp := <-w.packetChan:
if wp == nil { if wp == nil {
return return
@@ -150,56 +157,58 @@ func (w *worker) handle(wp *workerPacket) (io.Verdict, []byte) {
return io.VerdictAccept, nil return io.VerdictAccept, nil
} }
ipVersion := data[0] >> 4 if v, b, ok := w.handleIPPacket(wp, data); ok {
if ipVersion == 4 { return v, b
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil
}
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
return w.serializeModifiedUDP(data, l3, udp, transport, modPayload)
}
return v, nil
default:
return io.VerdictAccept, nil
}
} }
// Ethernet frame path (for custom PacketIO) // Ethernet frame fallback path (for custom PacketIO implementations).
if ipVersion == 6 { if l3Payload, ok := extractL3PayloadFromEthernet(data); ok {
// TODO: IPv6 support with raw parsing if v, b, ok := w.handleIPPacket(wp, l3Payload); ok {
return io.VerdictAccept, nil return v, b
}
} }
return io.VerdictAccept, nil return io.VerdictAccept, nil
} }
func (w *worker) handleIPPacket(wp *workerPacket, data []byte) (io.Verdict, []byte, bool) {
l3, transport, ok := ParseL3(data)
if !ok {
return io.VerdictAccept, nil, false
}
switch l3.Protocol {
case 6: // TCP
tcp, payload, ok := ParseTCP(transport)
if !ok {
return io.VerdictAccept, nil, true
}
verdict := w.tcpFlowMgr.handle(
wp.StreamID, l3, tcp, payload,
wp.SrcMAC, wp.DstMAC,
)
return verdict, nil, true
case 17: // UDP
udp, payload, ok := ParseUDP(transport)
if !ok {
return io.VerdictAccept, nil, true
}
v, modPayload := w.handleUDP(
wp.StreamID, l3, udp, payload,
wp.SrcMAC, wp.DstMAC,
)
if v == io.VerdictAcceptModify && modPayload != nil {
mv, mb := w.serializeModifiedUDP(data, l3, modPayload)
return mv, mb, true
}
return v, nil, true
default:
return io.VerdictAccept, nil, true
}
}
func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) { func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []byte, srcMAC, dstMAC net.HardwareAddr) (io.Verdict, []byte) {
ipSrc := net.IP(l3.SrcIP[:]) ipSrc := l3.SrcIPAddr()
ipDst := net.IP(l3.DstIP[:]) ipDst := l3.DstIPAddr()
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
if len(srcMAC) == 0 && w.macResolver != nil { if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc) srcMAC = w.macResolver.Resolve(ipSrc)
@@ -210,17 +219,18 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
SrcMAC: srcMAC, SrcMAC: srcMAC,
DstMAC: dstMAC, DstMAC: dstMAC,
} }
// Temporarily set payload on a UDP layer so existing UDP handling works.
w.udpSM.MatchWithContext(streamID, ipFlow, &layers.UDP{ tuple := canonicalUDPTupleKey(ipSrc, ipDst, udp.SrcPort, udp.DstPort)
BaseLayer: layers.BaseLayer{Payload: payload}, w.udpSM.MatchWithContext(streamID, tuple, false, payload, uc)
SrcPort: layers.UDPPort(udp.SrcPort),
DstPort: layers.UDPPort(udp.DstPort),
}, uc)
return io.Verdict(uc.Verdict), uc.Packet return io.Verdict(uc.Verdict), uc.Packet
} }
func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, transport []byte, modPayload []byte) (io.Verdict, []byte) { func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, modPayload []byte) (io.Verdict, []byte) {
ipPkt := gopacket.NewPacket(fullData, layers.LayerTypeIPv4, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) layerType := layers.LayerTypeIPv4
if l3.Version == 6 {
layerType = layers.LayerTypeIPv6
}
ipPkt := gopacket.NewPacket(fullData, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
netLayer := ipPkt.NetworkLayer() netLayer := ipPkt.NetworkLayer()
trLayer := ipPkt.TransportLayer() trLayer := ipPkt.TransportLayer()
if netLayer == nil || trLayer == nil { if netLayer == nil || trLayer == nil {
@@ -238,5 +248,30 @@ func (w *worker) serializeModifiedUDP(fullData []byte, l3 L3Info, udp UDPInfo, t
if err != nil { if err != nil {
return io.VerdictAccept, nil return io.VerdictAccept, nil
} }
return io.VerdictAcceptModify, w.modSerializeBuffer.Bytes() return io.VerdictAcceptModify, append([]byte(nil), w.modSerializeBuffer.Bytes()...)
}
func extractL3PayloadFromEthernet(data []byte) ([]byte, bool) {
if len(data) < 14 {
return nil, false
}
offset := 12
etherType := uint16(data[offset])<<8 | uint16(data[offset+1])
offset += 2
for etherType == 0x8100 || etherType == 0x88A8 {
if len(data) < offset+4 {
return nil, false
}
etherType = uint16(data[offset+2])<<8 | uint16(data[offset+3])
offset += 4
}
if etherType != 0x0800 && etherType != 0x86DD {
return nil, false
}
if len(data) <= offset {
return nil, false
}
return data[offset:], true
} }
+114
View File
@@ -0,0 +1,114 @@
package engine
import (
"net"
"testing"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
func TestWorkerHandleIPv6TCP(t *testing.T) {
w, err := newWorker(workerConfig{
ID: 0,
Logger: noopTestLogger{},
Ruleset: fixedRuleset{action: ruleset.ActionBlock},
ResultChan: make(chan workerResult, 1),
})
if err != nil {
t.Fatalf("new worker: %v", err)
}
src := net.ParseIP("2001:db8::11").To16()
dst := net.ParseIP("2001:db8::22").To16()
data := serializeIPv6TCP(t, src, dst, 42310, 443, 1000)
v, _ := w.handle(&workerPacket{
StreamID: 11,
Data: data,
})
if v != io.VerdictDropStream {
t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream)
}
}
func TestWorkerHandleIPv6UDP(t *testing.T) {
w, err := newWorker(workerConfig{
ID: 0,
Logger: noopTestLogger{},
Ruleset: fixedRuleset{action: ruleset.ActionBlock},
ResultChan: make(chan workerResult, 1),
})
if err != nil {
t.Fatalf("new worker: %v", err)
}
src := net.ParseIP("2001:db8::33").To16()
dst := net.ParseIP("2001:db8::44").To16()
data := serializeIPv6UDP(t, src, dst, 50000, 53, []byte("dns"))
v, _ := w.handle(&workerPacket{
StreamID: 12,
Data: data,
})
if v != io.VerdictDropStream {
t.Fatalf("verdict=%v want=%v", v, io.VerdictDropStream)
}
}
func serializeIPv6TCP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, seq uint32) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolTCP,
SrcIP: src,
DstIP: dst,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Seq: seq,
SYN: true,
}
if err := tcp.SetNetworkLayerForChecksum(ip6); err != nil {
t.Fatalf("set tcp checksum network layer: %v", err)
}
buf := gopacket.NewSerializeBuffer()
if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, ip6, tcp); err != nil {
t.Fatalf("serialize ipv6 tcp: %v", err)
}
return append([]byte(nil), buf.Bytes()...)
}
func serializeIPv6UDP(t *testing.T, src, dst net.IP, srcPort, dstPort uint16, payload []byte) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: src,
DstIP: dst,
}
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip6); err != nil {
t.Fatalf("set udp checksum network layer: %v", err)
}
buf := gopacket.NewSerializeBuffer()
if err := gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, ip6, udp, gopacket.Payload(payload)); err != nil {
t.Fatalf("serialize ipv6 udp: %v", err)
}
return append([]byte(nil), buf.Bytes()...)
}
+6 -2
View File
@@ -58,7 +58,7 @@ func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
} }
} else { } else {
table.Chains = []nftChainSpec{ table.Chains = []nftChainSpec{
{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"}, {Chain: "FORWARD", Header: "type filter hook forward priority mangle; policy accept;"},
} }
} }
for i := range table.Chains { for i := range table.Chains {
@@ -238,6 +238,9 @@ func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) err
return 0 return 0
} }
} }
if strings.Contains(e.Error(), "mismatched sequence") {
return 0
}
return okBoolToInt(cb(nil, e)) return okBoolToInt(cb(nil, e))
}) })
if err != nil { if err != nil {
@@ -346,6 +349,7 @@ func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
if remove { if remove {
err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules) err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
} else { } else {
_ = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules) err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
} }
return err return err
@@ -456,7 +460,7 @@ func ctIDFromCtBytes(ct []byte) uint32 {
return 0 return 0
} }
for _, attr := range ctAttrs { for _, attr := range ctAttrs {
if attr.Type == 12 { // CTA_ID if attr.Type == 12 && len(attr.Data) >= 4 { // CTA_ID
return binary.BigEndian.Uint32(attr.Data) return binary.BigEndian.Uint32(attr.Data)
} }
} }
+6
View File
@@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"sync"
"time" "time"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo" "git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
@@ -31,6 +32,7 @@ type V2GeoLoader struct {
DownloadFunc func(filename, url string) DownloadFunc func(filename, url string)
DownloadErrFunc func(err error) DownloadErrFunc func(err error)
mu sync.Mutex
geoipMap map[string]*v2geo.GeoIP geoipMap map[string]*v2geo.GeoIP
geositeMap map[string]*v2geo.GeoSite geositeMap map[string]*v2geo.GeoSite
} }
@@ -80,6 +82,8 @@ func (l *V2GeoLoader) download(filename, url string) error {
} }
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geoipMap != nil { if l.geoipMap != nil {
return l.geoipMap, nil return l.geoipMap, nil
} }
@@ -104,6 +108,8 @@ func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
} }
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.geositeMap != nil { if l.geositeMap != nil {
return l.geositeMap, nil return l.geositeMap, nil
} }
+69 -21
View File
@@ -112,23 +112,35 @@ type geositeDomain struct {
} }
type geositeMatcher struct { type geositeMatcher struct {
Domains []geositeDomain Domains []geositeDomain // legacy slow path for tests and manual construction
Plain []geositeDomain
Regex []geositeDomain
Root map[string]geositeDomain
Full map[string]geositeDomain
// Attributes are matched using "and" logic - if you have multiple attributes here, // Attributes are matched using "and" logic - if you have multiple attributes here,
// a domain must have all of those attributes to be considered a match. // a domain must have all of those attributes to be considered a match.
Attrs []string Attrs []string
} }
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool { func (m *geositeMatcher) attrsMatch(domain geositeDomain) bool {
// Match attributes first if len(m.Attrs) == 0 {
if len(m.Attrs) > 0 { return true
if len(domain.Attrs) == 0 { }
if len(domain.Attrs) == 0 {
return false
}
for _, attr := range m.Attrs {
if !domain.Attrs[attr] {
return false return false
} }
for _, attr := range m.Attrs { }
if !domain.Attrs[attr] { return true
return false }
}
} func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
// Match attributes first
if !m.attrsMatch(domain) {
return false
} }
switch domain.Type { switch domain.Type {
@@ -152,7 +164,35 @@ func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
} }
func (m *geositeMatcher) Match(host HostInfo) bool { func (m *geositeMatcher) Match(host HostInfo) bool {
for _, domain := range m.Domains { if host.Name == "" {
return false
}
if domain, ok := m.Full[host.Name]; ok && m.attrsMatch(domain) {
return true
}
for name := host.Name; name != ""; {
if domain, ok := m.Root[name]; ok && m.attrsMatch(domain) {
return true
}
idx := strings.IndexByte(name, '.')
if idx < 0 {
break
}
name = name[idx+1:]
}
for _, domain := range m.Plain {
if m.matchDomain(domain, host) {
return true
}
}
if len(m.Plain) == 0 && len(m.Regex) == 0 && len(m.Root) == 0 && len(m.Full) == 0 {
for _, domain := range m.Domains {
if m.matchDomain(domain, host) {
return true
}
}
}
for _, domain := range m.Regex {
if m.matchDomain(domain, host) { if m.matchDomain(domain, host) {
return true return true
} }
@@ -161,45 +201,53 @@ func (m *geositeMatcher) Match(host HostInfo) bool {
} }
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) { func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
domains := make([]geositeDomain, len(list.Domain)) matcher := &geositeMatcher{
for i, domain := range list.Domain { Root: make(map[string]geositeDomain),
Full: make(map[string]geositeDomain),
Attrs: attrs,
}
for _, domain := range list.Domain {
var compiled geositeDomain
switch domain.Type { switch domain.Type {
case v2geo.Domain_Plain: case v2geo.Domain_Plain:
domains[i] = geositeDomain{ compiled = geositeDomain{
Type: geositeDomainPlain, Type: geositeDomainPlain,
Value: domain.Value, Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute), Attrs: domainAttributeToMap(domain.Attribute),
} }
matcher.Plain = append(matcher.Plain, compiled)
case v2geo.Domain_Regex: case v2geo.Domain_Regex:
regex, err := regexp.Compile(domain.Value) regex, err := regexp.Compile(domain.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
domains[i] = geositeDomain{ compiled = geositeDomain{
Type: geositeDomainRegex, Type: geositeDomainRegex,
Value: domain.Value,
Regex: regex, Regex: regex,
Attrs: domainAttributeToMap(domain.Attribute), Attrs: domainAttributeToMap(domain.Attribute),
} }
matcher.Regex = append(matcher.Regex, compiled)
case v2geo.Domain_Full: case v2geo.Domain_Full:
domains[i] = geositeDomain{ compiled = geositeDomain{
Type: geositeDomainFull, Type: geositeDomainFull,
Value: domain.Value, Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute), Attrs: domainAttributeToMap(domain.Attribute),
} }
matcher.Full[domain.Value] = compiled
case v2geo.Domain_RootDomain: case v2geo.Domain_RootDomain:
domains[i] = geositeDomain{ compiled = geositeDomain{
Type: geositeDomainRoot, Type: geositeDomainRoot,
Value: domain.Value, Value: domain.Value,
Attrs: domainAttributeToMap(domain.Attribute), Attrs: domainAttributeToMap(domain.Attribute),
} }
matcher.Root[domain.Value] = compiled
default: default:
return nil, errors.New("unsupported domain type") return nil, errors.New("unsupported domain type")
} }
matcher.Domains = append(matcher.Domains, compiled)
} }
return &geositeMatcher{ return matcher, nil
Domains: domains,
Attrs: attrs,
}, nil
} }
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool { func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
+154 -20
View File
@@ -59,6 +59,8 @@ type compiledExprRule struct {
Log bool Log bool
ModInstance modifier.Instance ModInstance modifier.Instance
Program *vm.Program Program *vm.Program
Native nativeExpr
AnalyzerRefs map[string]analyzerRuleRef
GeoSiteConditions []string GeoSiteConditions []string
StartTimeSecs int // seconds since midnight, -1 if unset StartTimeSecs int // seconds since midnight, -1 if unset
StopTimeSecs int // seconds since midnight, -1 if unset StopTimeSecs int // seconds since midnight, -1 if unset
@@ -67,6 +69,7 @@ type compiledExprRule struct {
} }
var _ Ruleset = (*exprRuleset)(nil) var _ Ruleset = (*exprRuleset)(nil)
var _ LogFinalizer = (*exprRuleset)(nil)
var ( var (
envPool = sync.Pool{ envPool = sync.Pool{
@@ -102,10 +105,12 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
}() }()
} }
env := envPool.Get().(map[string]any) var env map[string]any
clear(env) var macMap, ipMap, portMap map[string]any
macMap, ipMap, portMap := populateExprEnv(env, info)
releaseEnv := func() { releaseEnv := func() {
if env == nil {
return
}
clear(env) clear(env)
envPool.Put(env) envPool.Put(env)
putSubMap(macMap) putSubMap(macMap)
@@ -113,31 +118,45 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
putSubMap(portMap) putSubMap(portMap)
} }
now := time.Now() now := time.Now()
logged := false
for _, rule := range r.Rules { for _, rule := range r.Rules {
if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) { if !matchTime(now, rule.StartTimeSecs, rule.StopTimeSecs, rule.Weekdays, rule.WeekdaysNegated) {
continue continue
} }
v, err := vm.Run(rule.Program, env) matched := false
if err != nil { if rule.Native != nil {
if r.stats != nil { matched = rule.Native.Match(info)
r.stats.MatchErrors.Add(1) } else {
if env == nil {
env = envPool.Get().(map[string]any)
clear(env)
macMap, ipMap, portMap = populateExprEnv(env, info)
} }
r.Logger.MatchError(info, rule.Name, err) v, err := vm.Run(rule.Program, env)
continue if err != nil {
if r.stats != nil {
r.stats.MatchErrors.Add(1)
}
r.Logger.MatchError(info, rule.Name, err)
continue
}
matched, _ = v.(bool)
} }
if vBool, ok := v.(bool); ok && vBool { if matched {
if rule.Log { if rule.Log {
logInfo := info logInfo := info
if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil { if len(rule.GeoSiteConditions) > 0 && r.GeoMatcher != nil {
logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions) logInfo = addGeoSiteLogMetadata(logInfo, r.GeoMatcher, rule.GeoSiteConditions)
} }
r.Logger.Log(logInfo, rule.Name) r.Logger.Log(logInfo, rule.Name)
logged = true
} }
if rule.Action != nil { if rule.Action != nil {
releaseEnv() releaseEnv()
return MatchResult{ return MatchResult{
Action: *rule.Action, Action: *rule.Action,
ModInstance: rule.ModInstance, ModInstance: rule.ModInstance,
Logged: logged,
} }
} }
} }
@@ -145,9 +164,40 @@ func (r *exprRuleset) Match(info StreamInfo) MatchResult {
releaseEnv() releaseEnv()
return MatchResult{ return MatchResult{
Action: ActionMaybe, Action: ActionMaybe,
Logged: logged,
} }
} }
func (r *exprRuleset) CanFinalizeAfterLog(info StreamInfo, activeAnalyzers []string) bool {
active := make(map[string]bool, len(activeAnalyzers))
for _, name := range activeAnalyzers {
active[name] = true
}
for _, rule := range r.Rules {
if rule.Action == nil {
continue
}
if *rule.Action == ActionModify {
return false
}
if rule.StartTimeSecs != -1 || rule.StopTimeSecs != -1 || len(rule.Weekdays) != 0 {
return false
}
for name, ref := range rule.AnalyzerRefs {
if !active[name] {
continue
}
if ref.ResponseSide {
return false
}
if _, ok := info.Props[name]; !ok {
return false
}
}
}
return true
}
func (r *exprRuleset) Stats() Stats { func (r *exprRuleset) Stats() Stats {
if r == nil || r.stats == nil { if r == nil || r.stats == nil {
return Stats{} return Stats{}
@@ -242,17 +292,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
if err != nil { if err != nil {
return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err) return nil, fmt.Errorf("rule %q has invalid weekdays: %w", rule.Name, err)
} }
var analyzerRefs map[string]analyzerRuleRef
if refTree, err := parser.Parse(rule.Expr); err == nil && refTree != nil {
analyzerRefs = collectAnalyzerRefs(refTree.Node, fullAnMap)
}
cr := compiledExprRule{ cr := compiledExprRule{
Name: rule.Name, Name: rule.Name,
Action: action, Action: action,
Log: rule.Log, Log: rule.Log,
Program: program, Program: program,
AnalyzerRefs: analyzerRefs,
GeoSiteConditions: extractGeoSiteConditions(rule.Expr), GeoSiteConditions: extractGeoSiteConditions(rule.Expr),
StartTimeSecs: startSecs, StartTimeSecs: startSecs,
StopTimeSecs: stopSecs, StopTimeSecs: stopSecs,
Weekdays: weekdays, Weekdays: weekdays,
WeekdaysNegated: weekdaysNegated, WeekdaysNegated: weekdaysNegated,
} }
cr.Native = compileNativeExpr(rule.Expr, funcMap, geoMatcher)
if action != nil && *action == ActionModify { if action != nil && *action == ActionModify {
mod, ok := fullModMap[rule.Modifier.Name] mod, ok := fullModMap[rule.Modifier.Name]
if !ok { if !ok {
@@ -266,11 +322,13 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
} }
compiledRules = append(compiledRules, cr) compiledRules = append(compiledRules, cr)
} }
// Convert the analyzer map to a list. depAns := make([]analyzer.Analyzer, 0, len(depAnMap))
var depAns []analyzer.Analyzer for _, a := range ans {
for _, a := range depAnMap { if depAnMap[a.Name()] != nil {
depAns = append(depAns, a) depAns = append(depAns, a)
}
} }
return &exprRuleset{ return &exprRuleset{
Rules: compiledRules, Rules: compiledRules,
Ans: depAns, Ans: depAns,
@@ -378,6 +436,58 @@ func (v *idVisitor) Visit(node *ast.Node) {
} }
} }
type analyzerRuleRef struct {
ResponseSide bool
}
type analyzerRefVisitor struct {
Analyzers map[string]analyzer.Analyzer
Refs map[string]analyzerRuleRef
}
func collectAnalyzerRefs(root ast.Node, analyzers map[string]analyzer.Analyzer) map[string]analyzerRuleRef {
visitor := &analyzerRefVisitor{
Analyzers: analyzers,
Refs: make(map[string]analyzerRuleRef),
}
ast.Walk(&root, visitor)
return visitor.Refs
}
func (v *analyzerRefVisitor) Visit(node *ast.Node) {
switch n := (*node).(type) {
case *ast.IdentifierNode:
if _, ok := v.Analyzers[n.Value]; ok {
v.add(n.Value, false)
}
case *ast.MemberNode:
path := memberPath(n)
if len(path) == 0 {
return
}
name := path[0]
if _, ok := v.Analyzers[name]; !ok {
return
}
v.add(name, len(path) > 1 && isResponseSideAnalyzerPath(path[1]))
}
}
func (v *analyzerRefVisitor) add(name string, responseSide bool) {
ref := v.Refs[name]
ref.ResponseSide = ref.ResponseSide || responseSide
v.Refs[name] = ref
}
func isResponseSideAnalyzerPath(name string) bool {
switch name {
case "resp", "server", "answers", "response":
return true
default:
return false
}
}
// idPatcher patches the AST during expr compilation, replacing certain values with // idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance. // their internal representations for better runtime performance.
type idPatcher struct { type idPatcher struct {
@@ -524,7 +634,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoIP, InitFunc: geoMatcher.LoadGeoIP,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoIp(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
}, },
@@ -532,7 +647,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite, InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(string)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSite(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
}, },
@@ -540,7 +660,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
InitFunc: geoMatcher.LoadGeoSite, InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil, PatchFunc: nil,
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSiteSet(params[0].(string), params[1].(*geo.SiteConditionSet)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(*geo.SiteConditionSet)
if !ok1 || !ok2 {
return false, nil
}
return geoMatcher.MatchGeoSiteSet(a, b), nil
}, },
Types: []reflect.Type{ Types: []reflect.Type{
reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)), reflect.TypeOf((func(string, *geo.SiteConditionSet) bool)(nil)),
@@ -561,7 +686,12 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
return nil return nil
}, },
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.IPNet)
if !ok1 || !ok2 {
return false, nil
}
return builtins.MatchCIDR(a, b), nil
}, },
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)}, Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
}, },
@@ -570,7 +700,6 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
PatchFunc: func(args *[]ast.Node) error { PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode var serverStr *ast.StringNode
if len(*args) > 1 { if len(*args) > 1 {
// Has the optional server argument
var ok bool var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode) serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok { if !ok {
@@ -600,9 +729,14 @@ func buildFunctionMap(config *BuiltinConfig, stats *statsCounters) (map[string]*
stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds())) stats.LookupLatencyNanos.Add(uint64(time.Since(start).Nanoseconds()))
}() }()
} }
a, ok1 := params[0].(string)
b, ok2 := params[1].(*net.Resolver)
if !ok1 || !ok2 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel() defer cancel()
out, err := params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) out, err := b.LookupHost(ctx, a)
if err != nil && stats != nil { if err != nil && stats != nil {
stats.LookupErrors.Add(1) stats.LookupErrors.Add(1)
} }
+98
View File
@@ -1,6 +1,7 @@
package ruleset package ruleset
import ( import (
"net"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@@ -12,6 +13,13 @@ import (
"github.com/expr-lang/expr/parser" "github.com/expr-lang/expr/parser"
) )
type testAnalyzer struct {
name string
}
func (a testAnalyzer) Name() string { return a.name }
func (a testAnalyzer) Limit() int { return 0 }
func TestExtractGeoSiteConditions(t *testing.T) { func TestExtractGeoSiteConditions(t *testing.T) {
expression := ` expression := `
(geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) && (geosite(tls.req.sni, "openai") || geosite(quic.req.sni, "OpenAI")) &&
@@ -88,3 +96,93 @@ func TestIDPatcher_PatchesGeoSiteORChainToGeoSiteSet(t *testing.T) {
t.Fatalf("expected OR chain to be collapsed, got %q", got) t.Fatalf("expected OR chain to be collapsed, got %q", got)
} }
} }
func TestCompileExprRulesPrunesUnusedAnalyzers(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "network-only", Action: "allow", Expr: `proto == "tcp" && port.dst == 443`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
exprRS := rs.(*exprRuleset)
if len(exprRS.Ans) != 0 {
t.Fatalf("expected no analyzers for network-only rule, got %d", len(exprRS.Ans))
}
if exprRS.Rules[0].Native == nil {
t.Fatalf("expected network-only rule to compile to native matcher")
}
got := rs.Match(StreamInfo{Protocol: ProtocolTCP, DstPort: 443})
if got.Action != ActionAllow {
t.Fatalf("native match action=%v want=%v", got.Action, ActionAllow)
}
}
func TestCompileExprRulesKeepsReferencedAnalyzersOnly(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "tls-only", Action: "allow", Expr: `tls != nil && tls.req != nil && tls.req.sni == "example.com"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}, testAnalyzer{name: "quic"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
exprRS := rs.(*exprRuleset)
if len(exprRS.Ans) != 1 || exprRS.Ans[0].Name() != "tls" {
t.Fatalf("expected only tls analyzer, got %#v", exprRS.Ans)
}
}
func TestNativeCIDRMatcher(t *testing.T) {
funcMap, geoMatcher := buildFunctionMapForTest()
n := compileNativeExpr(`cidr(ip.src, "192.168.1.0/24") && port.dst >= 80 && port.dst <= 443`, funcMap, geoMatcher)
if n == nil {
t.Fatal("expected native matcher")
}
if !n.Match(StreamInfo{SrcIP: net.ParseIP("192.168.1.10"), DstPort: 443}) {
t.Fatal("expected native CIDR matcher to match")
}
if n.Match(StreamInfo{SrcIP: net.ParseIP("10.0.0.1"), DstPort: 443}) {
t.Fatal("expected native CIDR matcher not to match")
}
}
func TestCanFinalizeAfterLogForRequestOnlyActionRules(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
{Name: "block-bad-host", Action: "block", Expr: `tls != nil && tls.req != nil && tls.req.sni == "bad.example"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
info := StreamInfo{
Props: analyzer.CombinedPropMap{
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
},
}
if !rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
t.Fatal("expected request-only rules to allow log finalization once request props exist")
}
}
func TestCanFinalizeAfterLogWaitsForResponseActionRules(t *testing.T) {
rs, err := CompileExprRules([]ExprRule{
{Name: "log-host", Log: true, Expr: `tls != nil && tls.req != nil && tls.req.sni != nil`},
{Name: "block-response", Action: "block", Expr: `tls != nil && tls.resp != nil && tls.resp.cipher_suite == "bad"`},
}, []analyzer.Analyzer{testAnalyzer{name: "tls"}}, nil, &BuiltinConfig{})
if err != nil {
t.Fatalf("CompileExprRules error: %v", err)
}
info := StreamInfo{
Props: analyzer.CombinedPropMap{
"tls": analyzer.PropMap{"req": analyzer.PropMap{"sni": "good.example"}},
},
}
if rs.(LogFinalizer).CanFinalizeAfterLog(info, []string{"tls"}) {
t.Fatal("expected response-side rule to keep inspection open")
}
}
func buildFunctionMapForTest() (map[string]*Function, *geo.GeoMatcher) {
m, g := buildFunctionMap(&BuiltinConfig{}, nil)
return m, g
}
+5
View File
@@ -85,6 +85,7 @@ func (i StreamInfo) DstString() string {
type MatchResult struct { type MatchResult struct {
Action Action Action Action
ModInstance modifier.Instance ModInstance modifier.Instance
Logged bool
} }
type Ruleset interface { type Ruleset interface {
@@ -96,6 +97,10 @@ type Ruleset interface {
Match(StreamInfo) MatchResult Match(StreamInfo) MatchResult
} }
type LogFinalizer interface {
CanFinalizeAfterLog(StreamInfo, []string) bool
}
type Stats struct { type Stats struct {
MatchCalls uint64 MatchCalls uint64
MatchErrors uint64 MatchErrors uint64
+273
View File
@@ -0,0 +1,273 @@
package ruleset
import (
"net"
"strconv"
"strings"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/parser"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
)
type nativeExpr interface {
Match(StreamInfo) bool
}
type nativeBoolFunc func(StreamInfo) bool
func (f nativeBoolFunc) Match(info StreamInfo) bool {
return f(info)
}
type nativeValueFunc func(StreamInfo) (any, bool)
func compileNativeExpr(expression string, funcMap map[string]*Function, gm *geo.GeoMatcher) nativeExpr {
tree, err := parser.Parse(expression)
if err != nil || tree == nil || tree.Node == nil {
return nil
}
root := tree.Node
patcher := &idPatcher{FuncMap: funcMap, GeoMatcher: gm}
ast.Walk(&root, patcher)
if patcher.Err != nil {
return nil
}
return compileNativeBool(root)
}
func compileNativeBool(node ast.Node) nativeExpr {
switch n := node.(type) {
case *ast.BinaryNode:
switch n.Operator {
case "&&", "and":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) && right.Match(info)
})
case "||", "or":
left := compileNativeBool(n.Left)
right := compileNativeBool(n.Right)
if left == nil || right == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return left.Match(info) || right.Match(info)
})
case "==", "!=", ">", ">=", "<", "<=":
left := compileNativeValue(n.Left)
right := compileNativeValue(n.Right)
if left == nil || right == nil {
return nil
}
op := n.Operator
return nativeBoolFunc(func(info StreamInfo) bool {
lv, lok := left(info)
rv, rok := right(info)
if !lok || !rok {
return false
}
result, ok := compareNativeValues(lv, rv, op)
return ok && result
})
default:
return nil
}
case *ast.UnaryNode:
if n.Operator != "!" && n.Operator != "not" {
return nil
}
child := compileNativeBool(n.Node)
if child == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
return !child.Match(info)
})
case *ast.CallNode:
return compileNativeCall(n)
case *ast.BoolNode:
value := n.Value
return nativeBoolFunc(func(StreamInfo) bool { return value })
default:
return nil
}
}
func compileNativeCall(n *ast.CallNode) nativeExpr {
id, ok := n.Callee.(*ast.IdentifierNode)
if !ok || strings.ToLower(id.Value) != "cidr" || len(n.Arguments) != 2 {
return nil
}
ipValue := compileNativeValue(n.Arguments[0])
if ipValue == nil {
return nil
}
var cidr *net.IPNet
switch arg := n.Arguments[1].(type) {
case *ast.ConstantNode:
cidr, _ = arg.Value.(*net.IPNet)
case *ast.StringNode:
_, parsed, err := net.ParseCIDR(arg.Value)
if err == nil {
cidr = parsed
}
}
if cidr == nil {
return nil
}
return nativeBoolFunc(func(info StreamInfo) bool {
value, ok := ipValue(info)
if !ok {
return false
}
switch v := value.(type) {
case net.IP:
return cidr.Contains(v)
case string:
ip := net.ParseIP(v)
return ip != nil && cidr.Contains(ip)
default:
return false
}
})
}
func compileNativeValue(node ast.Node) nativeValueFunc {
switch n := node.(type) {
case *ast.StringNode:
value := n.Value
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IntegerNode:
value := int64(n.Value)
return func(StreamInfo) (any, bool) { return value, true }
case *ast.IdentifierNode:
switch strings.ToLower(n.Value) {
case "proto":
return func(info StreamInfo) (any, bool) { return info.Protocol.String(), true }
default:
return nil
}
case *ast.MemberNode:
return compileNativeMember(n)
default:
return nil
}
}
func compileNativeMember(n *ast.MemberNode) nativeValueFunc {
path := memberPath(n)
switch strings.Join(path, ".") {
case "mac.src":
return func(info StreamInfo) (any, bool) { return info.SrcMAC.String(), true }
case "mac.dst":
return func(info StreamInfo) (any, bool) { return info.DstMAC.String(), true }
case "ip.src":
return func(info StreamInfo) (any, bool) { return info.SrcIP, info.SrcIP != nil }
case "ip.dst":
return func(info StreamInfo) (any, bool) { return info.DstIP, info.DstIP != nil }
case "port.src":
return func(info StreamInfo) (any, bool) { return int64(info.SrcPort), true }
case "port.dst":
return func(info StreamInfo) (any, bool) { return int64(info.DstPort), true }
default:
return nil
}
}
func memberPath(node ast.Node) []string {
switch n := node.(type) {
case *ast.IdentifierNode:
return []string{strings.ToLower(n.Value)}
case *ast.MemberNode:
base := memberPath(n.Node)
prop, ok := n.Property.(*ast.StringNode)
if !ok {
return nil
}
return append(base, strings.ToLower(prop.Value))
default:
return nil
}
}
func compareNativeValues(left, right any, op string) (bool, bool) {
if li, lok := nativeInt(left); lok {
ri, rok := nativeInt(right)
if !rok {
return false, false
}
return compareNativeOrdered(li, ri, op), true
}
ls, lok := nativeString(left)
if !lok {
return false, false
}
rs, rok := nativeString(right)
if !rok {
return false, false
}
switch op {
case "==":
return ls == rs, true
case "!=":
return ls != rs, true
default:
return false, false
}
}
func compareNativeOrdered(left, right int64, op string) bool {
switch op {
case "==":
return left == right
case "!=":
return left != right
case ">":
return left > right
case ">=":
return left >= right
case "<":
return left < right
case "<=":
return left <= right
default:
return false
}
}
func nativeInt(v any) (int64, bool) {
switch n := v.(type) {
case int:
return int64(n), true
case int64:
return n, true
case uint16:
return int64(n), true
case *ast.IntegerNode:
return int64(n.Value), true
default:
return 0, false
}
}
func nativeString(v any) (string, bool) {
switch s := v.(type) {
case string:
return s, true
case net.IP:
if s == nil {
return "", false
}
return s.String(), true
case int64:
return strconv.FormatInt(s, 10), true
default:
return "", false
}
}