tests: fix
This commit is contained in:
+56
-105
@@ -5,12 +5,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.difuse.io/Difuse/Mellaris/analyzer"
|
"git.difuse.io/Difuse/Mellaris/analyzer"
|
||||||
|
"git.difuse.io/Difuse/Mellaris/io"
|
||||||
"git.difuse.io/Difuse/Mellaris/ruleset"
|
"git.difuse.io/Difuse/Mellaris/ruleset"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/gopacket/reassembly"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type fixedRuleset struct {
|
type fixedRuleset struct {
|
||||||
@@ -137,137 +137,88 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPStreamUsesUpdatedRuleset(t *testing.T) {
|
func TestTCPFlowUsesUpdatedRuleset(t *testing.T) {
|
||||||
node, err := snowflake.NewNode(0)
|
node, err := snowflake.NewNode(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create node: %v", err)
|
t.Fatalf("create node: %v", err)
|
||||||
}
|
}
|
||||||
f := &tcpStreamFactory{
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
||||||
WorkerID: 0,
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||||
Logger: noopTestLogger{},
|
|
||||||
Node: node,
|
|
||||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
|
||||||
}
|
|
||||||
|
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
l3 := L3Info{
|
||||||
tcp := &layers.TCP{
|
Version: 4,
|
||||||
|
Protocol: 6,
|
||||||
|
SrcIP: [4]byte{10, 0, 0, 1},
|
||||||
|
DstIP: [4]byte{10, 0, 0, 2},
|
||||||
|
}
|
||||||
|
tcp := TCPInfo{
|
||||||
SrcPort: 12345,
|
SrcPort: 12345,
|
||||||
DstPort: 443,
|
DstPort: 443,
|
||||||
}
|
Seq: 100,
|
||||||
ctx := &tcpContext{
|
|
||||||
PacketMetadata: &gopacket.PacketMetadata{},
|
|
||||||
Verdict: tcpVerdictAccept,
|
|
||||||
}
|
|
||||||
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx)
|
|
||||||
s, ok := rs.(*tcpStream)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("unexpected stream type %T", rs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
|
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||||
t.Fatalf("update ruleset: %v", err)
|
if v != io.VerdictAcceptStream {
|
||||||
|
t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx)
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
|
||||||
if ctx.Verdict != tcpVerdictDropStream {
|
|
||||||
t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream)
|
tcp2 := TCPInfo{
|
||||||
|
SrcPort: 12345,
|
||||||
|
DstPort: 443,
|
||||||
|
Seq: 100,
|
||||||
|
}
|
||||||
|
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
|
||||||
|
if v != io.VerdictDropStream {
|
||||||
|
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) {
|
||||||
node, err := snowflake.NewNode(0)
|
node, err := snowflake.NewNode(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create node: %v", err)
|
t.Fatalf("create node: %v", err)
|
||||||
}
|
}
|
||||||
f := &tcpStreamFactory{
|
mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node)
|
||||||
WorkerID: 0,
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0)
|
||||||
Logger: noopTestLogger{},
|
|
||||||
Node: node,
|
|
||||||
Ruleset: fixedRuleset{action: ruleset.ActionAllow},
|
|
||||||
}
|
|
||||||
|
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4())
|
l3 := L3Info{
|
||||||
tcp := &layers.TCP{
|
Version: 4,
|
||||||
|
Protocol: 6,
|
||||||
|
SrcIP: [4]byte{10, 0, 0, 1},
|
||||||
|
DstIP: [4]byte{10, 0, 0, 2},
|
||||||
|
}
|
||||||
|
tcp := TCPInfo{
|
||||||
SrcPort: 12345,
|
SrcPort: 12345,
|
||||||
DstPort: 443,
|
DstPort: 443,
|
||||||
}
|
Seq: 100,
|
||||||
ctx1 := &tcpContext{
|
|
||||||
PacketMetadata: &gopacket.PacketMetadata{},
|
|
||||||
Verdict: tcpVerdictAccept,
|
|
||||||
}
|
|
||||||
rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1)
|
|
||||||
s, ok := rs.(*tcpStream)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("unexpected stream type %T", rs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
start1 := false
|
v := mgr.handle(1, l3, tcp, nil, nil, nil)
|
||||||
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) {
|
if v != io.VerdictAcceptStream {
|
||||||
t.Fatalf("unexpected Accept=false before first feed")
|
t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream)
|
||||||
}
|
|
||||||
s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1)
|
|
||||||
if ctx1.Verdict != tcpVerdictAcceptStream {
|
|
||||||
t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil {
|
mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1)
|
||||||
t.Fatalf("update ruleset: %v", err)
|
|
||||||
|
tcp2 := TCPInfo{
|
||||||
|
SrcPort: 12345,
|
||||||
|
DstPort: 443,
|
||||||
|
Seq: 100,
|
||||||
|
}
|
||||||
|
v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil)
|
||||||
|
if v != io.VerdictDropStream {
|
||||||
|
t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx2 := &tcpContext{
|
tcp3 := TCPInfo{
|
||||||
PacketMetadata: &gopacket.PacketMetadata{},
|
SrcPort: 12345,
|
||||||
Verdict: tcpVerdictAccept,
|
DstPort: 443,
|
||||||
|
Seq: 104,
|
||||||
}
|
}
|
||||||
start2 := false
|
v = mgr.handle(1, l3, tcp3, nil, nil, nil)
|
||||||
if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) {
|
if v != io.VerdictDropStream {
|
||||||
t.Fatalf("expected Accept=true after ruleset update")
|
t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream)
|
||||||
}
|
|
||||||
s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2)
|
|
||||||
if ctx2.Verdict != tcpVerdictDropStream {
|
|
||||||
t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx3 := &tcpContext{
|
|
||||||
PacketMetadata: &gopacket.PacketMetadata{},
|
|
||||||
Verdict: tcpVerdictAccept,
|
|
||||||
}
|
|
||||||
start3 := false
|
|
||||||
if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) {
|
|
||||||
t.Fatalf("expected Accept=false with unchanged ruleset and no active entries")
|
|
||||||
}
|
|
||||||
if ctx3.Verdict != tcpVerdictDropStream {
|
|
||||||
t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeScatterGather struct {
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s fakeScatterGather) Lengths() (int, int) {
|
|
||||||
return len(s.data), 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s fakeScatterGather) Fetch(length int) []byte {
|
|
||||||
if length < 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if length > len(s.data) {
|
|
||||||
length = len(s.data)
|
|
||||||
}
|
|
||||||
return s.data[:length]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fakeScatterGather) KeepFrom(int) {}
|
|
||||||
|
|
||||||
func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo {
|
|
||||||
return gopacket.CaptureInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) {
|
|
||||||
return reassembly.TCPDirClientToServer, true, false, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats {
|
|
||||||
return reassembly.TCPAssemblyStats{}
|
|
||||||
}
|
|
||||||
|
|||||||
+33
-33
@@ -49,55 +49,65 @@ type tcpFlowEntry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
||||||
if f.rulesetChanged() || f.virgin {
|
rs, version := f.currentRuleset()
|
||||||
f.virgin = false
|
rulesetChanged := version != f.rulesetVersion
|
||||||
return io.VerdictAccept
|
|
||||||
}
|
if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 {
|
||||||
if len(f.activeEntries) == 0 {
|
|
||||||
return f.lastVerdict
|
return f.lastVerdict
|
||||||
}
|
}
|
||||||
|
|
||||||
dir, rev := f.resolveDirection(tcp)
|
|
||||||
|
|
||||||
if tcp.RST || tcp.FIN {
|
if tcp.RST || tcp.FIN {
|
||||||
f.closeActiveEntries()
|
f.closeActiveEntries()
|
||||||
|
f.runMatch(rs, version, rulesetChanged)
|
||||||
f.maybeFinalizeVerdict()
|
f.maybeFinalizeVerdict()
|
||||||
return f.lastVerdict
|
return f.lastVerdict
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(payload) == 0 {
|
if len(payload) > 0 {
|
||||||
return io.VerdictAccept
|
dir, rev := f.resolveDirection(tcp)
|
||||||
}
|
|
||||||
|
|
||||||
expected := f.dirSeq[dir]
|
expected := f.dirSeq[dir]
|
||||||
if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected {
|
if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected {
|
||||||
return io.VerdictAccept
|
|
||||||
}
|
|
||||||
f.feedCalled[dir] = true
|
f.feedCalled[dir] = true
|
||||||
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
f.dirBuf[dir] = append(f.dirBuf[dir], payload...)
|
||||||
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
f.dirSeq[dir] = tcp.Seq + uint32(len(payload))
|
||||||
|
if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer {
|
||||||
if len(f.dirBuf[dir]) > tcpFlowMaxBuffer {
|
f.feedAnalyzers(rev)
|
||||||
return io.VerdictAccept
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updated := false
|
f.runMatch(rs, version, rulesetChanged)
|
||||||
|
f.maybeFinalizeVerdict()
|
||||||
|
return f.lastVerdict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tcpFlow) feedAnalyzers(rev bool) {
|
||||||
|
buf := f.dirBuf[uint8(tcpDirC2S)]
|
||||||
|
if rev {
|
||||||
|
buf = f.dirBuf[uint8(tcpDirS2C)]
|
||||||
|
}
|
||||||
for i := len(f.activeEntries) - 1; i >= 0; i-- {
|
for i := len(f.activeEntries) - 1; i >= 0; i-- {
|
||||||
entry := f.activeEntries[i]
|
entry := f.activeEntries[i]
|
||||||
update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir])
|
update, closeUpdate, done := feedFlowEntry(entry, rev, buf)
|
||||||
u1 := processPropUpdate(f.info.Props, entry.Name, update)
|
u1 := processPropUpdate(f.info.Props, entry.Name, update)
|
||||||
u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
|
u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate)
|
||||||
updated = updated || u1 || u2
|
if u1 || u2 {
|
||||||
|
f.logger.TCPStreamPropUpdate(f.info, false)
|
||||||
|
}
|
||||||
if done {
|
if done {
|
||||||
f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
|
f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...)
|
||||||
f.doneEntries = append(f.doneEntries, entry)
|
f.doneEntries = append(f.doneEntries, entry)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if updated {
|
func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool) {
|
||||||
f.logger.TCPStreamPropUpdate(f.info, false)
|
if !f.virgin && !rulesetChanged {
|
||||||
rs, version := f.currentRuleset()
|
return
|
||||||
|
}
|
||||||
|
f.virgin = false
|
||||||
f.rulesetVersion = version
|
f.rulesetVersion = version
|
||||||
|
|
||||||
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
result := ruleset.MatchResult{Action: ruleset.ActionMaybe}
|
||||||
if rs != nil {
|
if rs != nil {
|
||||||
result = rs.Match(f.info)
|
result = rs.Match(f.info)
|
||||||
@@ -108,14 +118,9 @@ func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict {
|
|||||||
f.lastVerdict = verdict
|
f.lastVerdict = verdict
|
||||||
f.closeActiveEntries()
|
f.closeActiveEntries()
|
||||||
f.logger.TCPStreamAction(f.info, action, false)
|
f.logger.TCPStreamAction(f.info, action, false)
|
||||||
return verdict
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.maybeFinalizeVerdict()
|
|
||||||
return f.lastVerdict
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tcpFlow) maybeFinalizeVerdict() {
|
func (f *tcpFlow) maybeFinalizeVerdict() {
|
||||||
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept {
|
if len(f.activeEntries) == 0 && f.lastVerdict == io.VerdictAccept {
|
||||||
f.lastVerdict = io.VerdictAcceptStream
|
f.lastVerdict = io.VerdictAcceptStream
|
||||||
@@ -137,11 +142,6 @@ func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) {
|
|||||||
return f.rulesetSource()
|
return f.rulesetSource()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *tcpFlow) rulesetChanged() bool {
|
|
||||||
_, version := f.currentRuleset()
|
|
||||||
return version != f.rulesetVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tcpFlow) closeActiveEntries() {
|
func (f *tcpFlow) closeActiveEntries() {
|
||||||
updated := false
|
updated := false
|
||||||
for _, entry := range f.activeEntries {
|
for _, entry := range f.activeEntries {
|
||||||
|
|||||||
@@ -175,7 +175,6 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by
|
|||||||
ipSrc := net.IP(l3.SrcIP[:])
|
ipSrc := net.IP(l3.SrcIP[:])
|
||||||
ipDst := net.IP(l3.DstIP[:])
|
ipDst := net.IP(l3.DstIP[:])
|
||||||
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
|
ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4())
|
||||||
udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)})
|
|
||||||
|
|
||||||
if len(srcMAC) == 0 && w.macResolver != nil {
|
if len(srcMAC) == 0 && w.macResolver != nil {
|
||||||
srcMAC = w.macResolver.Resolve(ipSrc)
|
srcMAC = w.macResolver.Resolve(ipSrc)
|
||||||
|
|||||||
+1
-1
@@ -211,7 +211,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
||||||
for i, nq := range nio.nqs {
|
for _, nq := range nio.nqs {
|
||||||
nq := nq
|
nq := nq
|
||||||
err := nq.RegisterWithErrorFunc(ctx,
|
err := nq.RegisterWithErrorFunc(ctx,
|
||||||
func(a nfqueue.Attribute) int {
|
func(a nfqueue.Attribute) int {
|
||||||
|
|||||||
Reference in New Issue
Block a user