diff --git a/engine/reload_rules_test.go b/engine/reload_rules_test.go index fa3fdea..04ae7c4 100644 --- a/engine/reload_rules_test.go +++ b/engine/reload_rules_test.go @@ -5,12 +5,12 @@ import ( "testing" "git.difuse.io/Difuse/Mellaris/analyzer" + "git.difuse.io/Difuse/Mellaris/io" "git.difuse.io/Difuse/Mellaris/ruleset" "github.com/bwmarrin/snowflake" "github.com/google/gopacket" "github.com/google/gopacket/layers" - "github.com/google/gopacket/reassembly" ) type fixedRuleset struct { @@ -137,137 +137,88 @@ func TestUDPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { } } -func TestTCPStreamUsesUpdatedRuleset(t *testing.T) { +func TestTCPFlowUsesUpdatedRuleset(t *testing.T) { node, err := snowflake.NewNode(0) if err != nil { t.Fatalf("create node: %v", err) } - f := &tcpStreamFactory{ - WorkerID: 0, - Logger: noopTestLogger{}, - Node: node, - Ruleset: fixedRuleset{action: ruleset.ActionAllow}, - } + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) + mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) - tcp := &layers.TCP{ + l3 := L3Info{ + Version: 4, + Protocol: 6, + SrcIP: [4]byte{10, 0, 0, 1}, + DstIP: [4]byte{10, 0, 0, 2}, + } + tcp := TCPInfo{ SrcPort: 12345, DstPort: 443, - } - ctx := &tcpContext{ - PacketMetadata: &gopacket.PacketMetadata{}, - Verdict: tcpVerdictAccept, - } - rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx) - s, ok := rs.(*tcpStream) - if !ok { - t.Fatalf("unexpected stream type %T", rs) + Seq: 100, } - if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { - t.Fatalf("update ruleset: %v", err) + v := mgr.handle(1, l3, tcp, nil, nil, nil) + if v != io.VerdictAcceptStream { + t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream) } - s.ReassembledSG(fakeScatterGather{data: []byte("payload")}, ctx) - if ctx.Verdict != tcpVerdictDropStream { - t.Fatalf("verdict=%v want=%v", ctx.Verdict, tcpVerdictDropStream) + mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1) + + tcp2 := TCPInfo{ + SrcPort: 12345, + DstPort: 443, + Seq: 100, + } + v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil) + if v != io.VerdictDropStream { + t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream) } } -func TestTCPStreamReevaluatesAfterRulesetVersionChange(t *testing.T) { +func TestTCPFlowReevaluatesAfterRulesetVersionChange(t *testing.T) { node, err := snowflake.NewNode(0) if err != nil { t.Fatalf("create node: %v", err) } - f := &tcpStreamFactory{ - WorkerID: 0, - Logger: noopTestLogger{}, - Node: node, - Ruleset: fixedRuleset{action: ruleset.ActionAllow}, - } + mgr := newTCPFlowManager(0, noopTestLogger{}, nil, node) + mgr.updateRuleset(fixedRuleset{action: ruleset.ActionAllow}, 0) - ipFlow := gopacket.NewFlow(layers.EndpointIPv4, net.IPv4(10, 0, 0, 1).To4(), net.IPv4(10, 0, 0, 2).To4()) - tcp := &layers.TCP{ + l3 := L3Info{ + Version: 4, + Protocol: 6, + SrcIP: [4]byte{10, 0, 0, 1}, + DstIP: [4]byte{10, 0, 0, 2}, + } + tcp := TCPInfo{ SrcPort: 12345, DstPort: 443, - } - ctx1 := &tcpContext{ - PacketMetadata: &gopacket.PacketMetadata{}, - Verdict: tcpVerdictAccept, - } - rs := f.New(ipFlow, tcp.TransportFlow(), tcp, ctx1) - s, ok := rs.(*tcpStream) - if !ok { - t.Fatalf("unexpected stream type %T", rs) + Seq: 100, } - start1 := false - if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start1, ctx1) { - t.Fatalf("unexpected Accept=false before first feed") - } - s.ReassembledSG(fakeScatterGather{data: []byte("first")}, ctx1) - if ctx1.Verdict != tcpVerdictAcceptStream { - t.Fatalf("verdict=%v want=%v", ctx1.Verdict, tcpVerdictAcceptStream) + v := mgr.handle(1, l3, tcp, nil, nil, nil) + if v != io.VerdictAcceptStream { + t.Fatalf("first verdict=%v want=%v", v, io.VerdictAcceptStream) } - if err := f.UpdateRuleset(fixedRuleset{action: ruleset.ActionBlock}); err != nil { - t.Fatalf("update ruleset: %v", err) + mgr.updateRuleset(fixedRuleset{action: ruleset.ActionBlock}, 1) + + tcp2 := TCPInfo{ + SrcPort: 12345, + DstPort: 443, + Seq: 100, + } + v = mgr.handle(2, l3, tcp2, []byte("data"), nil, nil) + if v != io.VerdictDropStream { + t.Fatalf("verdict after update=%v want=%v", v, io.VerdictDropStream) } - ctx2 := &tcpContext{ - PacketMetadata: &gopacket.PacketMetadata{}, - Verdict: tcpVerdictAccept, + tcp3 := TCPInfo{ + SrcPort: 12345, + DstPort: 443, + Seq: 104, } - start2 := false - if !s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start2, ctx2) { - t.Fatalf("expected Accept=true after ruleset update") - } - s.ReassembledSG(fakeScatterGather{data: []byte("second")}, ctx2) - if ctx2.Verdict != tcpVerdictDropStream { - t.Fatalf("verdict=%v want=%v", ctx2.Verdict, tcpVerdictDropStream) - } - - ctx3 := &tcpContext{ - PacketMetadata: &gopacket.PacketMetadata{}, - Verdict: tcpVerdictAccept, - } - start3 := false - if s.Accept(tcp, gopacket.CaptureInfo{}, reassembly.TCPDirClientToServer, 0, &start3, ctx3) { - t.Fatalf("expected Accept=false with unchanged ruleset and no active entries") - } - if ctx3.Verdict != tcpVerdictDropStream { - t.Fatalf("verdict=%v want=%v", ctx3.Verdict, tcpVerdictDropStream) + v = mgr.handle(1, l3, tcp3, nil, nil, nil) + if v != io.VerdictDropStream { + t.Fatalf("cached verdict after update=%v want=%v", v, io.VerdictDropStream) } } - -type fakeScatterGather struct { - data []byte -} - -func (s fakeScatterGather) Lengths() (int, int) { - return len(s.data), 0 -} - -func (s fakeScatterGather) Fetch(length int) []byte { - if length < 0 { - return nil - } - if length > len(s.data) { - length = len(s.data) - } - return s.data[:length] -} - -func (fakeScatterGather) KeepFrom(int) {} - -func (fakeScatterGather) CaptureInfo(int) gopacket.CaptureInfo { - return gopacket.CaptureInfo{} -} - -func (fakeScatterGather) Info() (reassembly.TCPFlowDirection, bool, bool, int) { - return reassembly.TCPDirClientToServer, true, false, 0 -} - -func (fakeScatterGather) Stats() reassembly.TCPAssemblyStats { - return reassembly.TCPAssemblyStats{} -} diff --git a/engine/tcp_flow.go b/engine/tcp_flow.go index 17020c3..d1c5cca 100644 --- a/engine/tcp_flow.go +++ b/engine/tcp_flow.go @@ -49,71 +49,76 @@ type tcpFlowEntry struct { } func (f *tcpFlow) feed(l3 L3Info, tcp TCPInfo, payload []byte) io.Verdict { - if f.rulesetChanged() || f.virgin { - f.virgin = false - return io.VerdictAccept - } - if len(f.activeEntries) == 0 { + rs, version := f.currentRuleset() + rulesetChanged := version != f.rulesetVersion + + if !f.virgin && !rulesetChanged && len(f.activeEntries) == 0 { return f.lastVerdict } - dir, rev := f.resolveDirection(tcp) - if tcp.RST || tcp.FIN { f.closeActiveEntries() + f.runMatch(rs, version, rulesetChanged) f.maybeFinalizeVerdict() return f.lastVerdict } - if len(payload) == 0 { - return io.VerdictAccept + if len(payload) > 0 { + dir, rev := f.resolveDirection(tcp) + expected := f.dirSeq[dir] + if !f.feedCalled[dir] || expected == 0 || tcp.Seq == expected { + f.feedCalled[dir] = true + f.dirBuf[dir] = append(f.dirBuf[dir], payload...) + f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) + if len(f.dirBuf[dir]) <= tcpFlowMaxBuffer { + f.feedAnalyzers(rev) + } + } } - expected := f.dirSeq[dir] - if f.feedCalled[dir] && expected != 0 && tcp.Seq != expected { - return io.VerdictAccept - } - f.feedCalled[dir] = true - f.dirBuf[dir] = append(f.dirBuf[dir], payload...) - f.dirSeq[dir] = tcp.Seq + uint32(len(payload)) + f.runMatch(rs, version, rulesetChanged) + f.maybeFinalizeVerdict() + return f.lastVerdict +} - if len(f.dirBuf[dir]) > tcpFlowMaxBuffer { - return io.VerdictAccept +func (f *tcpFlow) feedAnalyzers(rev bool) { + buf := f.dirBuf[uint8(tcpDirC2S)] + if rev { + buf = f.dirBuf[uint8(tcpDirS2C)] } - - updated := false for i := len(f.activeEntries) - 1; i >= 0; i-- { entry := f.activeEntries[i] - update, closeUpdate, done := feedFlowEntry(entry, rev, f.dirBuf[dir]) + update, closeUpdate, done := feedFlowEntry(entry, rev, buf) u1 := processPropUpdate(f.info.Props, entry.Name, update) u2 := processPropUpdate(f.info.Props, entry.Name, closeUpdate) - updated = updated || u1 || u2 + if u1 || u2 { + f.logger.TCPStreamPropUpdate(f.info, false) + } if done { f.activeEntries = append(f.activeEntries[:i], f.activeEntries[i+1:]...) f.doneEntries = append(f.doneEntries, entry) } } +} - if updated { - f.logger.TCPStreamPropUpdate(f.info, false) - rs, version := f.currentRuleset() - f.rulesetVersion = version - result := ruleset.MatchResult{Action: ruleset.ActionMaybe} - if rs != nil { - result = rs.Match(f.info) - } - action := result.Action - if action != ruleset.ActionMaybe && action != ruleset.ActionModify { - verdict := actionToTCPVerdict(action) - f.lastVerdict = verdict - f.closeActiveEntries() - f.logger.TCPStreamAction(f.info, action, false) - return verdict - } +func (f *tcpFlow) runMatch(rs ruleset.Ruleset, version uint64, rulesetChanged bool) { + if !f.virgin && !rulesetChanged { + return } + f.virgin = false + f.rulesetVersion = version - f.maybeFinalizeVerdict() - return f.lastVerdict + result := ruleset.MatchResult{Action: ruleset.ActionMaybe} + if rs != nil { + result = rs.Match(f.info) + } + action := result.Action + if action != ruleset.ActionMaybe && action != ruleset.ActionModify { + verdict := actionToTCPVerdict(action) + f.lastVerdict = verdict + f.closeActiveEntries() + f.logger.TCPStreamAction(f.info, action, false) + } } func (f *tcpFlow) maybeFinalizeVerdict() { @@ -137,11 +142,6 @@ func (f *tcpFlow) currentRuleset() (ruleset.Ruleset, uint64) { return f.rulesetSource() } -func (f *tcpFlow) rulesetChanged() bool { - _, version := f.currentRuleset() - return version != f.rulesetVersion -} - func (f *tcpFlow) closeActiveEntries() { updated := false for _, entry := range f.activeEntries { diff --git a/engine/worker.go b/engine/worker.go index ade57a7..96fbe86 100644 --- a/engine/worker.go +++ b/engine/worker.go @@ -175,7 +175,6 @@ func (w *worker) handleUDP(streamID uint32, l3 L3Info, udp UDPInfo, payload []by ipSrc := net.IP(l3.SrcIP[:]) ipDst := net.IP(l3.DstIP[:]) ipFlow := gopacket.NewFlow(layers.EndpointIPv4, ipSrc.To4(), ipDst.To4()) - udpFlow := gopacket.NewFlow(layers.EndpointUDPPort, []byte{byte(udp.SrcPort >> 8), byte(udp.SrcPort)}, []byte{byte(udp.DstPort >> 8), byte(udp.DstPort)}) if len(srcMAC) == 0 && w.macResolver != nil { srcMAC = w.macResolver.Resolve(ipSrc) diff --git a/io/nfqueue.go b/io/nfqueue.go index 3fb0247..e880a4a 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -211,7 +211,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { } func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { - for i, nq := range nio.nqs { + for _, nq := range nio.nqs { nq := nq err := nq.RegisterWithErrorFunc(ctx, func(a nfqueue.Attribute) int {