tests: fix

This commit is contained in:
2026-05-12 15:29:00 +00:00
parent ecc2cde1c2
commit e8fdf1268b
4 changed files with 102 additions and 152 deletions
+56 -105
View File
@@ -5,12 +5,12 @@ import (
"testing" "testing"
"git.difuse.io/Difuse/Mellaris/analyzer" "git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset" "git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake" "github.com/bwmarrin/snowflake"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
) )
type fixedRuleset struct { type fixedRuleset struct {
@@ -137,137 +137,88 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
} }
} }
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
node, err := snowflake.NewNode(0) node, err := snowflake.NewNode(0)
if err != nil { if err != nil {
t.Fatalf("create node: %v", err) t.Fatalf("create node: %v", err)
} }
f := &tcpStreamFactory{ mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
WorkerID: 0, mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) l3 := L3Info{
tcp := &layers.TCP{ Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345, SrcPort: 12345,
DstPort: 443, DstPort: 443,
} Seq: 100,
ctx := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
} }
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { v := mgr.handle(1, l3, tcp, nil, nil, nil)
t.Fatalf("update ruleset: %v", err) if v != io.VerdictAcceptStream {
t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
} }
s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx) mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
if ctx.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream) tcp2 := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
if v != io.VerdictDropStream {
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
} }
} }
func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
node, err := snowflake.NewNode(0) node, err := snowflake.NewNode(0)
if err != nil { if err != nil {
t.Fatalf("create node: %v", err) t.Fatalf("create node: %v", err)
} }
f := &tcpStreamFactory{ mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
WorkerID: 0, mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) l3 := L3Info{
tcp := &layers.TCP{ Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345, SrcPort: 12345,
DstPort: 443, DstPort: 443,
} Seq: 100,
ctx1 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
} }
start1 := false v := mgr.handle(1, l3, tcp, nil, nil, nil)
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) { if v != io.VerdictAcceptStream {
t.Fatalf("unexpected Accept=false before first feed") t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
}
s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1)
if ctx1.Verdict != tcpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream)
} }
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
t.Fatalf("update ruleset: %v", err)
tcp2 := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
if v != io.VerdictDropStream {
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
} }
ctx2 := &tcpContext{ tcp3 := TCPInfo{
PacketMetadata: &gopacket.PacketMetadata{}, SrcPort: 12345,
Verdict: tcpVerdictAccept, DstPort: 443,
Seq: 104,
} }
start2 := false v = mgr.handle(1, l3, tcp3, nil, nil, nil)
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) { if v != io.VerdictDropStream {
t.Fatalf("expected Accept=true after ruleset update") t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
}
s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2)
if ctx2.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream)
}
ctx3 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
start3 := false
if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
}
if ctx3.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream)
} }
} }
type fakeScatterGather struct {
data []byte
}
func (s fakeScatterGather) Lengths() (int, int) {
return len(s.data), 0
}
func (s fakeScatterGather) Fetch(length int) []byte {
if length < 0 {
return nil
}
if length > len(s.data) {
length = len(s.data)
}
return s.data[:length]
}
func (fakeScatterGather) KeepFrom(int) {}
func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo {
return gopacket.CaptureInfo{}
}
func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) {
return reassembly.TCPDirClientToServer, true, false, 0
}
func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats {
return reassembly.TCPAssemblyStats{}
}
+33 -33
View File
@@ -49,55 +49,65 @@ type tcpFlowEntry struct {
} }
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
if f.rulesetChanged() || f.virgin { rs, version := f.currentRuleset()
f.virgin = false rulesetChanged := version != f.rulesetVersion
return io.VerdictAccept
} if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
if len(f.activeEntries) == 0 {
return f.lastVerdict return f.lastVerdict
} }
dir, rev := f.resolveDirection(tcp)
if tcp.RST || tcp.FIN { if tcp.RST || tcp.FIN {
f.closeActiveEntries() f.closeActiveEntries()
f.runMatch(rs, version, rulesetChanged)
f.maybeFinalizeVerdict() f.maybeFinalizeVerdict()
return f.lastVerdict return f.lastVerdict
} }
if len(payload) == 0 { if len(payload) > 0 {
return io.VerdictAccept dir, rev := f.resolveDirection(tcp)
}
expected := f.dirSeq[dir] expected := f.dirSeq[dir]
if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected { if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
return io.VerdictAccept
}
f.feedCalled[dir] = true f.feedCalled[dir] = true
f.dirBuf[dir] = append(f.dirBuf[dir], payload...) f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
if len(f.dirBuf[dir]) > tcpFlowMaxBuffer { f.feedAnalyzers(rev)
return io.VerdictAccept }
}
} }
updated := false f.runMatch(rs, version, rulesetChanged)
f.maybeFinalizeVerdict()
return f.lastVerdict
}
func (f *tcpFlow) feedAnalyzers(rev bool) {
buf := f.dirBuf[uint8(tcpDirC2S)]
if rev {
buf = f.dirBuf[uint8(tcpDirS2C)]
}
for i := len(f.activeEntries) - 1; i >= 0; i-- { for i := len(f.activeEntries) - 1; i >= 0; i-- {
entry := f.activeEntries[i] entry := f.activeEntries[i]
update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir]) update, closeUpdate, done := feedFlowEntry(entry, rev, buf)
u1 := processPropUpdate(f.info.Props, entry.Name, update) u1 := processPropUpdate(f.info.Props, entry.Name, update)
u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate) u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
updated = updated || u1 || u2 if u1 || u2 {
f.logger.TCPStreamPropUpdate(f.info, false)
}
if done { if done {
f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...) f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
f.doneEntries = append(f.doneEntries, entry) f.doneEntries = append(f.doneEntries, entry)
} }
} }
}
if updated { func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool) {
f.logger.TCPStreamPropUpdate(f.info, false) if !f.virgin && !rulesetChanged {
rs, version := f.currentRuleset() return
}
f.virgin = false
f.rulesetVersion = version f.rulesetVersion = version
result := ruleset.MatchResult{Action: ruleset.ActionMaybe} result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
if rs != nil { if rs != nil {
result = rs.Match(f.info) result = rs.Match(f.info)
@@ -108,12 +118,7 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
f.lastVerdict = verdict f.lastVerdict = verdict
f.closeActiveEntries() f.closeActiveEntries()
f.logger.TCPStreamAction(f.info, action, false) f.logger.TCPStreamAction(f.info, action, false)
return verdict
} }
}
f.maybeFinalizeVerdict()
return f.lastVerdict
} }
func (f *tcpFlow) maybeFinalizeVerdict() { func (f *tcpFlow) maybeFinalizeVerdict() {
@@ -137,11 +142,6 @@ func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
return f.rulesetSource() return f.rulesetSource()
} }
func (f *tcpFlow) rulesetChanged() bool {
_, version := f.currentRuleset()
return version != f.rulesetVersion
}
func (f *tcpFlow) closeActiveEntries() { func (f *tcpFlow) closeActiveEntries() {
updated := false updated := false
for _, entry := range f.activeEntries { for _, entry := range f.activeEntries {
-1
View File
@@ -175,7 +175,6 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
ipSrc := net.IP(l3.SrcIP[:]) ipSrc := net.IP(l3.SrcIP[:])
ipDst := net.IP(l3.DstIP[:]) ipDst := net.IP(l3.DstIP[:])
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4()) ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)})
if len(srcMAC) == 0 && w.macResolver != nil { if len(srcMAC) == 0 && w.macResolver != nil {
srcMAC = w.macResolver.Resolve(ipSrc) srcMAC = w.macResolver.Resolve(ipSrc)
+1 -1
View File
@@ -211,7 +211,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
} }
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
for i, nq := range nio.nqs { for _, nq := range nio.nqs {
nq := nq nq := nq
err := nq.RegisterWithErrorFunc(ctx, err := nq.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int { func(a nfqueue.Attribute) int {