diff --git a/analyzer/udp/internal/quic/payload.go b/analyzer/udp/internal/quic/payload.go index 6bb670f..769391a 100644 --- a/analyzer/udp/internal/quic/payload.go +++ b/analyzer/udp/internal/quic/payload.go @@ -2,11 +2,13 @@ package quic import ( "bytes" + "container/list" "crypto" "errors" "fmt" "io" "sort" + "sync" "github.com/quic-go/quic-go/quicvarint" "golang.org/x/crypto/hkdf" @@ -16,10 +18,30 @@ var defaultPNMaxGuesses = []int64{ 0, 1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, } +const ( + initialSecretLabelClientIn = "client in" + initialSecretLabelServerIn = "server in" +) + +const initialProtectorCacheSize = 512 + +var initialProtectorCache = newInitialPacketProtectorCache(initialProtectorCacheSize) + +type DecryptSuccessHint struct { + ConnectionID []byte + SecretLabel string + PacketNumberMax int64 +} + type ReadCryptoFramesOptions struct { AdditionalConnectionIDs [][]byte TryServerSecret bool PacketNumberMaxGuesses []int64 + PreferredConnectionID []byte + PreferredSecretLabel string + PreferredPNMax int64 + HasPreferredPNMax bool + SuccessHint *DecryptSuccessHint } func ReadCryptoPayload(packet []byte) ([]byte, error) { @@ -61,33 +83,18 @@ func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) ( } packetView := packet[:offset+hdr.Length] - candidateConnIDs := [][]byte{hdr.DestConnectionID} - if opts != nil { - candidateConnIDs = append(candidateConnIDs, opts.AdditionalConnectionIDs...) - } - candidateConnIDs = uniqueNonEmptyConnectionIDs(candidateConnIDs) - - pnMaxGuesses := defaultPNMaxGuesses - if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 { - pnMaxGuesses = opts.PacketNumberMaxGuesses - } - - labels := []string{"client in"} - if opts != nil && opts.TryServerSecret { - labels = append(labels, "server in") - } + candidateConnIDs := collectConnectionIDCandidates(hdr, opts) + pnMaxGuesses := collectPacketNumberMaxGuesses(opts) + labels := collectSecretLabels(opts) var lastErr error for _, connID := range candidateConnIDs { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(hdr.Version)) for _, label := range labels { - secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size()) - key, err := NewInitialProtectionKey(secret, hdr.Version) + pp, err := getOrCreateInitialPacketProtector(hdr.Version, connID, label) if err != nil { - lastErr = fmt.Errorf("NewInitialProtectionKey: %w", err) + lastErr = err continue } - pp := NewPacketProtector(key) for _, pnMax := range pnMaxGuesses { packetCopy := append([]byte(nil), packetView...) unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax) @@ -100,6 +107,11 @@ func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) ( lastErr = err continue } + if opts != nil && opts.SuccessHint != nil { + opts.SuccessHint.ConnectionID = append(opts.SuccessHint.ConnectionID[:0], connID...) + opts.SuccessHint.SecretLabel = label + opts.SuccessHint.PacketNumberMax = pnMax + } return frs, nil } } @@ -337,6 +349,68 @@ func skipN(r *bytes.Reader, n uint64) error { return err } +func collectConnectionIDCandidates(hdr *Header, opts *ReadCryptoFramesOptions) [][]byte { + ids := make([][]byte, 0, 2) + if opts != nil && len(opts.PreferredConnectionID) > 0 { + ids = append(ids, opts.PreferredConnectionID) + } + ids = append(ids, hdr.DestConnectionID) + if opts != nil { + ids = append(ids, opts.AdditionalConnectionIDs...) + } + return uniqueNonEmptyConnectionIDs(ids) +} + +func collectPacketNumberMaxGuesses(opts *ReadCryptoFramesOptions) []int64 { + guesses := defaultPNMaxGuesses + if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 { + guesses = opts.PacketNumberMaxGuesses + } + if opts == nil || !opts.HasPreferredPNMax { + return uniqueInt64PreserveOrder(guesses) + } + out := make([]int64, 0, len(guesses)+1) + out = append(out, opts.PreferredPNMax) + out = append(out, guesses...) + return uniqueInt64PreserveOrder(out) +} + +func collectSecretLabels(opts *ReadCryptoFramesOptions) []string { + labels := []string{initialSecretLabelClientIn} + if opts != nil && opts.TryServerSecret { + labels = append(labels, initialSecretLabelServerIn) + } + if opts == nil || opts.PreferredSecretLabel == "" { + return labels + } + return prependStringIfPresent(labels, opts.PreferredSecretLabel) +} + +func prependStringIfPresent(base []string, preferred string) []string { + if preferred == "" { + return base + } + has := false + for _, s := range base { + if s == preferred { + has = true + break + } + } + if !has { + return base + } + out := make([]string, 0, len(base)) + out = append(out, preferred) + for _, s := range base { + if s == preferred { + continue + } + out = append(out, s) + } + return out +} + func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte { out := make([][]byte, 0, len(ids)) seen := make(map[string]struct{}, len(ids)) @@ -354,6 +428,103 @@ func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte { return out } +func uniqueInt64PreserveOrder(values []int64) []int64 { + out := make([]int64, 0, len(values)) + seen := make(map[int64]struct{}, len(values)) + for _, v := range values { + if _, ok := seen[v]; ok { + continue + } + seen[v] = struct{}{} + out = append(out, v) + } + return out +} + +type initialPacketProtectorCacheKey struct { + Version uint32 + ConnID string + Label string +} + +type initialPacketProtectorCacheEntry struct { + Key initialPacketProtectorCacheKey + Value *PacketProtector +} + +type initialPacketProtectorCache struct { + mutex sync.Mutex + capacity int + ll *list.List + items map[initialPacketProtectorCacheKey]*list.Element +} + +func newInitialPacketProtectorCache(capacity int) *initialPacketProtectorCache { + if capacity <= 0 { + capacity = 1 + } + return &initialPacketProtectorCache{ + capacity: capacity, + ll: list.New(), + items: make(map[initialPacketProtectorCacheKey]*list.Element, capacity), + } +} + +func (c *initialPacketProtectorCache) Get(key initialPacketProtectorCacheKey) (*PacketProtector, bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + elem, ok := c.items[key] + if !ok { + return nil, false + } + c.ll.MoveToFront(elem) + return elem.Value.(*initialPacketProtectorCacheEntry).Value, true +} + +func (c *initialPacketProtectorCache) Add(key initialPacketProtectorCacheKey, value *PacketProtector) { + c.mutex.Lock() + defer c.mutex.Unlock() + if elem, ok := c.items[key]; ok { + entry := elem.Value.(*initialPacketProtectorCacheEntry) + entry.Value = value + c.ll.MoveToFront(elem) + return + } + elem := c.ll.PushFront(&initialPacketProtectorCacheEntry{ + Key: key, + Value: value, + }) + c.items[key] = elem + for c.ll.Len() > c.capacity { + oldest := c.ll.Back() + if oldest == nil { + break + } + c.ll.Remove(oldest) + delete(c.items, oldest.Value.(*initialPacketProtectorCacheEntry).Key) + } +} + +func getOrCreateInitialPacketProtector(version uint32, connID []byte, label string) (*PacketProtector, error) { + key := initialPacketProtectorCacheKey{ + Version: version, + ConnID: string(connID), + Label: label, + } + if cached, ok := initialProtectorCache.Get(key); ok { + return cached, nil + } + initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(version)) + secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size()) + protectionKey, err := NewInitialProtectionKey(secret, version) + if err != nil { + return nil, fmt.Errorf("NewInitialProtectionKey: %w", err) + } + protector := NewPacketProtector(protectionKey) + initialProtectorCache.Add(key, protector) + return protector, nil +} + // assembleCryptoFrames assembles multiple crypto frames into a single slice (if possible). // It returns an error if the frames cannot be assembled. This can happen if the frames are not contiguous. func assembleCryptoFrames(frames []CryptoFrame) []byte { diff --git a/analyzer/udp/internal/quic/payload_test.go b/analyzer/udp/internal/quic/payload_test.go index 4fe7e21..5b6f2ae 100644 --- a/analyzer/udp/internal/quic/payload_test.go +++ b/analyzer/udp/internal/quic/payload_test.go @@ -66,3 +66,114 @@ func TestReadCryptoFrames_NonInitialHeader(t *testing.T) { t.Fatalf("ReadCryptoFrames() error = %v, want %v", err, ErrNotInitialPacket) } } + +func TestCollectPacketNumberMaxGuesses_Prefer(t *testing.T) { + got := collectPacketNumberMaxGuesses(&ReadCryptoFramesOptions{ + PacketNumberMaxGuesses: []int64{1, 2, 3, 2}, + PreferredPNMax: 2, + HasPreferredPNMax: true, + }) + want := []int64{2, 1, 3} + if len(got) != len(want) { + t.Fatalf("collectPacketNumberMaxGuesses() len=%d, want=%d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("collectPacketNumberMaxGuesses()[%d]=%d, want=%d (got=%v)", i, got[i], want[i], got) + } + } +} + +func TestReadCryptoFramesWithOptions_HintRoundTrip(t *testing.T) { + // Example packet from quic.xargs.org client Initial, padded to 1200 bytes. + packet := make([]byte, 1200) + clientInitial := []byte{ + 0xcd, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x05, 0x63, 0x5f, 0x63, 0x69, 0x64, 0x00, 0x41, 0x03, 0x98, + 0x1c, 0x36, 0xa7, 0xed, 0x78, 0x71, 0x6b, 0xe9, 0x71, 0x1b, 0xa4, 0x98, + 0xb7, 0xed, 0x86, 0x84, 0x43, 0xbb, 0x2e, 0x0c, 0x51, 0x4d, 0x4d, 0x84, + 0x8e, 0xad, 0xcc, 0x7a, 0x00, 0xd2, 0x5c, 0xe9, 0xf9, 0xaf, 0xa4, 0x83, + 0x97, 0x80, 0x88, 0xde, 0x83, 0x6b, 0xe6, 0x8c, 0x0b, 0x32, 0xa2, 0x45, + 0x95, 0xd7, 0x81, 0x3e, 0xa5, 0x41, 0x4a, 0x91, 0x99, 0x32, 0x9a, 0x6d, + 0x9f, 0x7f, 0x76, 0x0d, 0xd8, 0xbb, 0x24, 0x9b, 0xf3, 0xf5, 0x3d, 0x9a, + 0x77, 0xfb, 0xb7, 0xb3, 0x95, 0xb8, 0xd6, 0x6d, 0x78, 0x79, 0xa5, 0x1f, + 0xe5, 0x9e, 0xf9, 0x60, 0x1f, 0x79, 0x99, 0x8e, 0xb3, 0x56, 0x8e, 0x1f, + 0xdc, 0x78, 0x9f, 0x64, 0x0a, 0xca, 0xb3, 0x85, 0x8a, 0x82, 0xef, 0x29, + 0x30, 0xfa, 0x5c, 0xe1, 0x4b, 0x5b, 0x9e, 0xa0, 0xbd, 0xb2, 0x9f, 0x45, + 0x72, 0xda, 0x85, 0xaa, 0x3d, 0xef, 0x39, 0xb7, 0xef, 0xaf, 0xff, 0xa0, + 0x74, 0xb9, 0x26, 0x70, 0x70, 0xd5, 0x0b, 0x5d, 0x07, 0x84, 0x2e, 0x49, + 0xbb, 0xa3, 0xbc, 0x78, 0x7f, 0xf2, 0x95, 0xd6, 0xae, 0x3b, 0x51, 0x43, + 0x05, 0xf1, 0x02, 0xaf, 0xe5, 0xa0, 0x47, 0xb3, 0xfb, 0x4c, 0x99, 0xeb, + 0x92, 0xa2, 0x74, 0xd2, 0x44, 0xd6, 0x04, 0x92, 0xc0, 0xe2, 0xe6, 0xe2, + 0x12, 0xce, 0xf0, 0xf9, 0xe3, 0xf6, 0x2e, 0xfd, 0x09, 0x55, 0xe7, 0x1c, + 0x76, 0x8a, 0xa6, 0xbb, 0x3c, 0xd8, 0x0b, 0xbb, 0x37, 0x55, 0xc8, 0xb7, + 0xeb, 0xee, 0x32, 0x71, 0x2f, 0x40, 0xf2, 0x24, 0x51, 0x19, 0x48, 0x70, + 0x21, 0xb4, 0xb8, 0x4e, 0x15, 0x65, 0xe3, 0xca, 0x31, 0x96, 0x7a, 0xc8, + 0x60, 0x4d, 0x40, 0x32, 0x17, 0x0d, 0xec, 0x28, 0x0a, 0xee, 0xfa, 0x09, + 0x5d, 0x08, 0xb3, 0xb7, 0x24, 0x1e, 0xf6, 0x64, 0x6a, 0x6c, 0x86, 0xe5, + 0xc6, 0x2c, 0xe0, 0x8b, 0xe0, 0x99, + } + copy(packet, clientInitial) + + firstHint := &DecryptSuccessHint{} + frames, err := ReadCryptoFramesWithOptions(packet, &ReadCryptoFramesOptions{ + TryServerSecret: true, + SuccessHint: firstHint, + }) + if err != nil { + t.Fatalf("ReadCryptoFramesWithOptions(first) error=%v", err) + } + if len(frames) == 0 { + t.Fatal("ReadCryptoFramesWithOptions(first) got no frames, want > 0") + } + if firstHint.SecretLabel != initialSecretLabelClientIn { + t.Fatalf("firstHint.SecretLabel=%q, want=%q", firstHint.SecretLabel, initialSecretLabelClientIn) + } + if len(firstHint.ConnectionID) == 0 { + t.Fatal("firstHint.ConnectionID empty, want non-empty") + } + + secondHint := &DecryptSuccessHint{} + frames, err = ReadCryptoFramesWithOptions(packet, &ReadCryptoFramesOptions{ + TryServerSecret: true, + PreferredConnectionID: firstHint.ConnectionID, + PreferredSecretLabel: firstHint.SecretLabel, + PreferredPNMax: firstHint.PacketNumberMax, + HasPreferredPNMax: true, + SuccessHint: secondHint, + }) + if err != nil { + t.Fatalf("ReadCryptoFramesWithOptions(second) error=%v", err) + } + if len(frames) == 0 { + t.Fatal("ReadCryptoFramesWithOptions(second) got no frames, want > 0") + } + if secondHint.PacketNumberMax != firstHint.PacketNumberMax { + t.Fatalf("secondHint.PacketNumberMax=%d, want=%d", secondHint.PacketNumberMax, firstHint.PacketNumberMax) + } +} + +func TestGetOrCreateInitialPacketProtector_CacheReuse(t *testing.T) { + packet := mustHexDecodeString(` + c7ff0000200008f067a5502a4262b500 4075fb12ff07823a5d24534d906ce4c7 + 6782a2167e3479c0f7f6395dc2c91676 302fe6d70bb7cbeb117b4ddb7d173498 + 44fd61dae200b8338e1b932976b61d91 e64a02e9e0ee72e3a6f63aba4ceeeec5 + be2f24f2d86027572943533846caa13e 6f163fb257473d0eda5047360fd4a47e + fd8142fafc0f76 + `) + hdr, _, err := ParseInitialHeader(packet) + if err != nil { + t.Fatalf("ParseInitialHeader() error=%v", err) + } + got1, err := getOrCreateInitialPacketProtector(hdr.Version, hdr.DestConnectionID, initialSecretLabelServerIn) + if err != nil { + t.Fatalf("getOrCreateInitialPacketProtector() #1 error=%v", err) + } + got2, err := getOrCreateInitialPacketProtector(hdr.Version, hdr.DestConnectionID, initialSecretLabelServerIn) + if err != nil { + t.Fatalf("getOrCreateInitialPacketProtector() #2 error=%v", err) + } + if got1 != got2 { + t.Fatal("expected cache hit to return same protector pointer") + } +} diff --git a/analyzer/udp/quic.go b/analyzer/udp/quic.go index 2e39f0d..cd3cdcf 100644 --- a/analyzer/udp/quic.go +++ b/analyzer/udp/quic.go @@ -45,6 +45,8 @@ type quicStream struct { frames map[int64][]byte maxEnd int64 connIDs [][]byte + lastHint quic.DecryptSuccessHint + hasLastHint bool } func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) { @@ -59,10 +61,19 @@ func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done b s.rememberConnID(hdr.DestConnectionID) s.rememberConnID(hdr.SrcConnectionID) } - frs, err := quic.ReadCryptoFramesWithOptions(data, &quic.ReadCryptoFramesOptions{ + hint := quic.DecryptSuccessHint{} + opts := &quic.ReadCryptoFramesOptions{ AdditionalConnectionIDs: s.connIDs, TryServerSecret: true, - }) + SuccessHint: &hint, + } + if s.hasLastHint { + opts.PreferredConnectionID = append([]byte(nil), s.lastHint.ConnectionID...) + opts.PreferredSecretLabel = s.lastHint.SecretLabel + opts.PreferredPNMax = s.lastHint.PacketNumberMax + opts.HasPreferredPNMax = true + } + frs, err := quic.ReadCryptoFramesWithOptions(data, opts) if err != nil { if errors.Is(err, quic.ErrNotInitialPacket) { return nil, false @@ -86,6 +97,11 @@ func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done b s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } + if len(hint.ConnectionID) > 0 { + s.lastHint = hint + s.hasLastHint = true + s.rememberConnID(hint.ConnectionID) + } if len(frs) == 0 { s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold