tests: fix

This commit is contained in:
2026-05-12 15:29:00 +00:00
parent ecc2cde1c2
commit e8fdf1268b
4 changed files with 102 additions and 152 deletions
+56 -105
View File
@@ -5,12 +5,12 @@ import (
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
type fixedRuleset struct {
@@ -137,137 +137,88 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
}
}
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &tcpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
tcp := &layers.TCP{
l3 := L3Info{
Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345,
DstPort: 443,
}
ctx := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
Seq: 100,
}
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
v := mgr.handle(1, l3, tcp, nil, nil, nil)
if v != io.VerdictAcceptStream {
t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
}
s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx)
if ctx.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
tcp2 := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
if v != io.VerdictDropStream {
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
}
}
func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
node, err := snowflake.NewNode(0)
if err != nil {
t.Fatalf("create node: %v", err)
}
f := &tcpStreamFactory{
WorkerID: 0,
Logger: noopTestLogger{},
Node: node,
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
}
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
tcp := &layers.TCP{
l3 := L3Info{
Version: 4,
Protocol: 6,
SrcIP: [4]byte{10, 0, 0, 1},
DstIP: [4]byte{10, 0, 0, 2},
}
tcp := TCPInfo{
SrcPort: 12345,
DstPort: 443,
}
ctx1 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1)
s, ok := rs.(*tcpStream)
if !ok {
t.Fatalf("unexpected stream type %T", rs)
Seq: 100,
}
start1 := false
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) {
t.Fatalf("unexpected Accept=false before first feed")
}
s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1)
if ctx1.Verdict != tcpVerdictAcceptStream {
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream)
v := mgr.handle(1, l3, tcp, nil, nil, nil)
if v != io.VerdictAcceptStream {
t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
}
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
t.Fatalf("update ruleset: %v", err)
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
tcp2 := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 100,
}
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
if v != io.VerdictDropStream {
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
}
ctx2 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
tcp3 := TCPInfo{
SrcPort: 12345,
DstPort: 443,
Seq: 104,
}
start2 := false
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) {
t.Fatalf("expected Accept=true after ruleset update")
}
s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2)
if ctx2.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream)
}
ctx3 := &tcpContext{
PacketMetadata: &gopacket.PacketMetadata{},
Verdict: tcpVerdictAccept,
}
start3 := false
if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) {
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
}
if ctx3.Verdict != tcpVerdictDropStream {
t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream)
v = mgr.handle(1, l3, tcp3, nil, nil, nil)
if v != io.VerdictDropStream {
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
}
}
type fakeScatterGather struct {
data []byte
}
func (s fakeScatterGather) Lengths() (int, int) {
return len(s.data), 0
}
func (s fakeScatterGather) Fetch(length int) []byte {
if length < 0 {
return nil
}
if length > len(s.data) {
length = len(s.data)
}
return s.data[:length]
}
func (fakeScatterGather) KeepFrom(int) {}
func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo {
return gopacket.CaptureInfo{}
}
func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) {
return reassembly.TCPDirClientToServer, true, false, 0
}
func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats {
return reassembly.TCPAssemblyStats{}
}