package quic import ( "bytes" "container/list" "crypto" "errors" "fmt" "io" "sort" "sync" "github.com/quic-go/quic-go/quicvarint" "golang.org/x/crypto/hkdf" ) var defaultPNMaxGuesses = []int64{ 0, 1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, } const ( initialSecretLabelClientIn = "client in" initialSecretLabelServerIn = "server in" ) const initialProtectorCacheSize = 512 var initialProtectorCache = newInitialPacketProtectorCache(initialProtectorCacheSize) type DecryptSuccessHint struct { ConnectionID []byte SecretLabel string PacketNumberMax int64 } type ReadCryptoFramesOptions struct { AdditionalConnectionIDs [][]byte TryServerSecret bool PacketNumberMaxGuesses []int64 PreferredConnectionID []byte PreferredSecretLabel string PreferredPNMax int64 HasPreferredPNMax bool SuccessHint *DecryptSuccessHint } func ReadCryptoPayload(packet []byte) ([]byte, error) { frs, err := ReadCryptoFrames(packet) if err != nil { return nil, err } data := assembleCryptoFrames(frs) if data == nil { return nil, errors.New("unable to assemble crypto frames") } return data, nil } // ReadCryptoFrames decrypts a QUIC Initial client packet and returns CRYPTO frames. func ReadCryptoFrames(packet []byte) ([]CryptoFrame, error) { return ReadCryptoFramesWithOptions(packet, nil) } // ReadCryptoFramesWithOptions decrypts a QUIC Initial packet and returns CRYPTO frames. func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) ([]CryptoFrame, error) { hdr, offset, err := ParseInitialHeader(packet) if err != nil { return nil, err } // Some sanity checks if hdr.Version != V1 && hdr.Version != V2 { return nil, fmt.Errorf("unsupported version: %x", hdr.Version) } if offset == 0 || hdr.Length == 0 { return nil, errors.New("invalid packet") } // https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-client-initial // // "The unprotected header includes the connection ID and a 4-byte packet number encoding for a packet number of 2" if int64(len(packet)) < offset+hdr.Length { return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), offset+hdr.Length) } packetView := packet[:offset+hdr.Length] candidateConnIDs := collectConnectionIDCandidates(hdr, opts) pnMaxGuesses := collectPacketNumberMaxGuesses(opts) labels := collectSecretLabels(opts) var lastErr error for _, connID := range candidateConnIDs { for _, label := range labels { pp, err := getOrCreateInitialPacketProtector(hdr.Version, connID, label) if err != nil { lastErr = err continue } for _, pnMax := range pnMaxGuesses { packetCopy := append([]byte(nil), packetView...) unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax) if err != nil { lastErr = err continue } frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload)) if err != nil { lastErr = err continue } if opts != nil && opts.SuccessHint != nil { opts.SuccessHint.ConnectionID = append(opts.SuccessHint.ConnectionID[:0], connID...) opts.SuccessHint.SecretLabel = label opts.SuccessHint.PacketNumberMax = pnMax } return frs, nil } } } if lastErr != nil { return nil, lastErr } return nil, errors.New("unable to decrypt initial packet") } const ( paddingFrameType = 0x00 pingFrameType = 0x01 cryptoFrameType = 0x06 ) type CryptoFrame struct { Offset int64 Data []byte } func extractCryptoFrames(r *bytes.Reader) ([]CryptoFrame, error) { var frames []CryptoFrame for r.Len() > 0 { typ, err := quicvarint.Read(r) if err != nil { return nil, err } switch typ { case paddingFrameType, pingFrameType, 0x1e: // PADDING, PING, HANDSHAKE_DONE: no payload. continue case 0x02, 0x03: // ACK, ACK_ECN if _, err := quicvarint.Read(r); err != nil { // Largest Acknowledged return nil, err } if _, err := quicvarint.Read(r); err != nil { // ACK Delay return nil, err } ackRangeCount, err := quicvarint.Read(r) if err != nil { return nil, err } if _, err := quicvarint.Read(r); err != nil { // First ACK Range return nil, err } for i := uint64(0); i < ackRangeCount; i++ { if _, err := quicvarint.Read(r); err != nil { // Gap return nil, err } if _, err := quicvarint.Read(r); err != nil { // ACK Range Length return nil, err } } if typ == 0x03 { if _, err := quicvarint.Read(r); err != nil { // ECT0 Count return nil, err } if _, err := quicvarint.Read(r); err != nil { // ECT1 Count return nil, err } if _, err := quicvarint.Read(r); err != nil { // ECN-CE Count return nil, err } } case 0x04: // RESET_STREAM if _, err := quicvarint.Read(r); err != nil { // Stream ID return nil, err } if _, err := quicvarint.Read(r); err != nil { // Application Error Code return nil, err } if _, err := quicvarint.Read(r); err != nil { // Final Size return nil, err } case 0x05: // STOP_SENDING if _, err := quicvarint.Read(r); err != nil { // Stream ID return nil, err } if _, err := quicvarint.Read(r); err != nil { // Application Error Code return nil, err } case cryptoFrameType: // CRYPTO var frame CryptoFrame offset, err := quicvarint.Read(r) if err != nil { return nil, err } frame.Offset = int64(offset) dataLen, err := quicvarint.Read(r) if err != nil { return nil, err } frame.Data = make([]byte, dataLen) if _, err := io.ReadFull(r, frame.Data); err != nil { return nil, err } frames = append(frames, frame) case 0x07: // NEW_TOKEN tokenLen, err := quicvarint.Read(r) if err != nil { return nil, err } if err := skipN(r, tokenLen); err != nil { return nil, err } case 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: // STREAM if _, err := quicvarint.Read(r); err != nil { // Stream ID return nil, err } hasOffset := typ&0x04 != 0 hasLength := typ&0x02 != 0 if hasOffset { if _, err := quicvarint.Read(r); err != nil { // Offset return nil, err } } var dataLen uint64 if hasLength { n, err := quicvarint.Read(r) if err != nil { return nil, err } dataLen = n } else { dataLen = uint64(r.Len()) } if err := skipN(r, dataLen); err != nil { return nil, err } case 0x10, 0x12, 0x13, 0x14, 0x16, 0x17, 0x19: // MAX_DATA, MAX_STREAMS_*, DATA_BLOCKED, STREAMS_BLOCKED_*, RETIRE_CONNECTION_ID if _, err := quicvarint.Read(r); err != nil { return nil, err } case 0x11, 0x15: // MAX_STREAM_DATA, STREAM_DATA_BLOCKED if _, err := quicvarint.Read(r); err != nil { return nil, err } if _, err := quicvarint.Read(r); err != nil { return nil, err } case 0x18: // NEW_CONNECTION_ID if _, err := quicvarint.Read(r); err != nil { // Sequence Number return nil, err } if _, err := quicvarint.Read(r); err != nil { // Retire Prior To return nil, err } cidLen, err := r.ReadByte() if err != nil { return nil, err } if cidLen > 20 { return nil, fmt.Errorf("invalid connection ID length: %d", cidLen) } if err := skipN(r, uint64(cidLen)); err != nil { // Connection ID return nil, err } if err := skipN(r, 16); err != nil { // Stateless Reset Token return nil, err } case 0x1a, 0x1b: // PATH_CHALLENGE, PATH_RESPONSE if err := skipN(r, 8); err != nil { return nil, err } case 0x1c: // CONNECTION_CLOSE (transport) if _, err := quicvarint.Read(r); err != nil { // Error Code return nil, err } if _, err := quicvarint.Read(r); err != nil { // Frame Type return nil, err } reasonLen, err := quicvarint.Read(r) if err != nil { return nil, err } if err := skipN(r, reasonLen); err != nil { return nil, err } case 0x1d: // CONNECTION_CLOSE (application) if _, err := quicvarint.Read(r); err != nil { // Error Code return nil, err } reasonLen, err := quicvarint.Read(r) if err != nil { return nil, err } if err := skipN(r, reasonLen); err != nil { return nil, err } case 0x30, 0x31: // DATAGRAM var dataLen uint64 if typ&0x01 != 0 { n, err := quicvarint.Read(r) if err != nil { return nil, err } dataLen = n } else { dataLen = uint64(r.Len()) } if err := skipN(r, dataLen); err != nil { return nil, err } default: // Unknown/extension frame type: if we already collected CRYPTO // frames, return them instead of failing hard. if len(frames) > 0 { return frames, nil } return nil, fmt.Errorf("unsupported frame type: %d", typ) } } return frames, nil } func skipN(r *bytes.Reader, n uint64) error { if n > uint64(r.Len()) { return io.EOF } _, err := r.Seek(int64(n), io.SeekCurrent) return err } func collectConnectionIDCandidates(hdr *Header, opts *ReadCryptoFramesOptions) [][]byte { ids := make([][]byte, 0, 2) if opts != nil && len(opts.PreferredConnectionID) > 0 { ids = append(ids, opts.PreferredConnectionID) } ids = append(ids, hdr.DestConnectionID) if opts != nil { ids = append(ids, opts.AdditionalConnectionIDs...) } return uniqueNonEmptyConnectionIDs(ids) } func collectPacketNumberMaxGuesses(opts *ReadCryptoFramesOptions) []int64 { guesses := defaultPNMaxGuesses if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 { guesses = opts.PacketNumberMaxGuesses } if opts == nil || !opts.HasPreferredPNMax { return uniqueInt64PreserveOrder(guesses) } out := make([]int64, 0, len(guesses)+1) out = append(out, opts.PreferredPNMax) out = append(out, guesses...) return uniqueInt64PreserveOrder(out) } func collectSecretLabels(opts *ReadCryptoFramesOptions) []string { labels := []string{initialSecretLabelClientIn} if opts != nil && opts.TryServerSecret { labels = append(labels, initialSecretLabelServerIn) } if opts == nil || opts.PreferredSecretLabel == "" { return labels } return prependStringIfPresent(labels, opts.PreferredSecretLabel) } func prependStringIfPresent(base []string, preferred string) []string { if preferred == "" { return base } has := false for _, s := range base { if s == preferred { has = true break } } if !has { return base } out := make([]string, 0, len(base)) out = append(out, preferred) for _, s := range base { if s == preferred { continue } out = append(out, s) } return out } func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte { out := make([][]byte, 0, len(ids)) seen := make(map[string]struct{}, len(ids)) for _, id := range ids { if len(id) == 0 { continue } k := string(id) if _, ok := seen[k]; ok { continue } seen[k] = struct{}{} out = append(out, append([]byte(nil), id...)) } return out } func uniqueInt64PreserveOrder(values []int64) []int64 { out := make([]int64, 0, len(values)) seen := make(map[int64]struct{}, len(values)) for _, v := range values { if _, ok := seen[v]; ok { continue } seen[v] = struct{}{} out = append(out, v) } return out } type initialPacketProtectorCacheKey struct { Version uint32 ConnID string Label string } type initialPacketProtectorCacheEntry struct { Key initialPacketProtectorCacheKey Value *PacketProtector } type initialPacketProtectorCache struct { mutex sync.Mutex capacity int ll *list.List items map[initialPacketProtectorCacheKey]*list.Element } func newInitialPacketProtectorCache(capacity int) *initialPacketProtectorCache { if capacity <= 0 { capacity = 1 } return &initialPacketProtectorCache{ capacity: capacity, ll: list.New(), items: make(map[initialPacketProtectorCacheKey]*list.Element, capacity), } } func (c *initialPacketProtectorCache) Get(key initialPacketProtectorCacheKey) (*PacketProtector, bool) { c.mutex.Lock() defer c.mutex.Unlock() elem, ok := c.items[key] if !ok { return nil, false } c.ll.MoveToFront(elem) return elem.Value.(*initialPacketProtectorCacheEntry).Value, true } func (c *initialPacketProtectorCache) Add(key initialPacketProtectorCacheKey, value *PacketProtector) { c.mutex.Lock() defer c.mutex.Unlock() if elem, ok := c.items[key]; ok { entry := elem.Value.(*initialPacketProtectorCacheEntry) entry.Value = value c.ll.MoveToFront(elem) return } elem := c.ll.PushFront(&initialPacketProtectorCacheEntry{ Key: key, Value: value, }) c.items[key] = elem for c.ll.Len() > c.capacity { oldest := c.ll.Back() if oldest == nil { break } c.ll.Remove(oldest) delete(c.items, oldest.Value.(*initialPacketProtectorCacheEntry).Key) } } func getOrCreateInitialPacketProtector(version uint32, connID []byte, label string) (*PacketProtector, error) { key := initialPacketProtectorCacheKey{ Version: version, ConnID: string(connID), Label: label, } if cached, ok := initialProtectorCache.Get(key); ok { return cached, nil } initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(version)) secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size()) protectionKey, err := NewInitialProtectionKey(secret, version) if err != nil { return nil, fmt.Errorf("NewInitialProtectionKey: %w", err) } protector := NewPacketProtector(protectionKey) initialProtectorCache.Add(key, protector) return protector, nil } // assembleCryptoFrames assembles multiple crypto frames into a single slice (if possible). // It returns an error if the frames cannot be assembled. This can happen if the frames are not contiguous. func assembleCryptoFrames(frames []CryptoFrame) []byte { if len(frames) == 0 { return nil } if len(frames) == 1 { return frames[0].Data } // sort the frames by offset sort.Slice(frames, func(i, j int) bool { return frames[i].Offset < frames[j].Offset }) // check if the frames are contiguous for i := 1; i < len(frames); i++ { if frames[i].Offset != frames[i-1].Offset+int64(len(frames[i-1].Data)) { return nil } } // concatenate the frames data := make([]byte, frames[len(frames)-1].Offset+int64(len(frames[len(frames)-1].Data))) for _, frame := range frames { copy(data[frame.Offset:], frame.Data) } return data }