package quic

import (
	"bytes"
	"crypto"
	"errors"
	"fmt"
	"io"
	"sort"

	"github.com/quic-go/quic-go/quicvarint"
	"golang.org/x/crypto/hkdf"
)

// ReadCryptoPayload decrypts a QUIC Initial client packet and returns the
// data of its CRYPTO frames joined into a single slice.
//
// It returns an error in any of these cases:
//   - the packet cannot be decrypted or parsed (see ReadCryptoFrames);
//   - the packet contains no CRYPTO frames;
//   - its CRYPTO frames do not form one contiguous run.
func ReadCryptoPayload(packet []byte) ([]byte, error) {
	frs, err := ReadCryptoFrames(packet)
	if err != nil {
		return nil, err
	}
	data := assembleCryptoFrames(frs)
	if data == nil {
		return nil, errors.New("unable to assemble crypto frames")
	}
	return data, nil
}

// ReadCryptoFrames decrypts a QUIC Initial client packet and returns CRYPTO frames.
//
// Only versions V1 and V2 are accepted. The client Initial secret is derived
// from the packet's Destination Connection ID and a version-specific salt
// (HKDF-Extract, then HKDF-Expand-Label "client in"). Decryption is attempted
// with several packet-number guesses; the first attempt whose plaintext parses
// as a frame sequence is returned. If every attempt fails, the error from the
// last attempt is returned.
func ReadCryptoFrames(packet []byte) ([]CryptoFrame, error) {
	hdr, offset, err := ParseInitialHeader(packet)
	if err != nil {
		return nil, err
	}

	// Some sanity checks
	if hdr.Version != V1 && hdr.Version != V2 {
		return nil, fmt.Errorf("unsupported version: %x", hdr.Version)
	}
	if offset == 0 || hdr.Length == 0 {
		return nil, errors.New("invalid packet")
	}

	initialSecret := hkdf.Extract(crypto.SHA256.New, hdr.DestConnectionID, getSalt(hdr.Version))
	clientSecret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, "client in", []byte{}, crypto.SHA256.Size())
	key, err := NewInitialProtectionKey(clientSecret, hdr.Version)
	if err != nil {
		return nil, fmt.Errorf("NewInitialProtectionKey: %w", err)
	}

	pp := NewPacketProtector(key)

	// Restrict decryption to this packet; anything past offset+Length belongs
	// to a coalesced packet in the same datagram.
	if int64(len(packet)) < offset+hdr.Length {
		return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), offset+hdr.Length)
	}
	packetView := packet[:offset+hdr.Length]

	// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-client-initial
	//
	// "The unprotected header includes the connection ID and a 4-byte packet number encoding for a packet number of 2"
	//
	// NOTE(review): the third UnProtect argument presumably acts as the
	// "largest packet number seen so far" used to expand the truncated packet
	// number. A passive reader does not know it, so a few small values are
	// tried; confirm against UnProtect.
	pnMaxGuesses := []int64{0, 1, 2, 3, 4, 8, 16}
	var lastErr error
	for _, pnMax := range pnMaxGuesses {
		// Each attempt gets its own copy. This is presumably needed because
		// UnProtect rewrites the buffer in place (header-protection removal /
		// decryption); TODO confirm.
		packetCopy := append([]byte(nil), packetView...)
		unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax)
		if err != nil {
			lastErr = err
			continue
		}
		frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload))
		if err != nil {
			lastErr = err
			continue
		}
		return frs, nil
	}
	if lastErr != nil {
		return nil, lastErr
	}
	return nil, errors.New("unable to decrypt initial packet")
}

// Frame types referred to by name in extractCryptoFrames. The other frame
// types are matched by their numeric values (RFC 9000, Section 19).
const (
	paddingFrameType = 0x00
	pingFrameType    = 0x01
	cryptoFrameType  = 0x06
)

// CryptoFrame is a single decoded CRYPTO frame. Offset is the byte offset of
// Data within the crypto stream, exactly as carried in the frame.
type CryptoFrame struct {
	Offset int64
	Data   []byte
}

// extractCryptoFrames walks a decrypted packet payload and returns every
// CRYPTO frame it contains, in the order they appear.
//
// Other known frame types are parsed only as far as needed to skip over them.
// If an unrecognized frame type is encountered, the frames collected so far are
// returned. If none have been collected yet, an error is returned, because the
// rest of the payload cannot be delimited.
func extractCryptoFrames(r *bytes.Reader) ([]CryptoFrame, error) {
	var frames []CryptoFrame
	for r.Len() > 0 {
		typ, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}

		switch typ {
		case paddingFrameType, pingFrameType, 0x1e: // PADDING, PING, HANDSHAKE_DONE: no payload.
			continue
		case 0x02, 0x03: // ACK, ACK_ECN
			if _, err := quicvarint.Read(r); err != nil { // Largest Acknowledged
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // ACK Delay
				return nil, err
			}
			ackRangeCount, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // First ACK Range
				return nil, err
			}
			for i := uint64(0); i < ackRangeCount; i++ {
				if _, err := quicvarint.Read(r); err != nil { // Gap
					return nil, err
				}
				if _, err := quicvarint.Read(r); err != nil { // ACK Range Length
					return nil, err
				}
			}
			if typ == 0x03 {
				if _, err := quicvarint.Read(r); err != nil { // ECT0 Count
					return nil, err
				}
				if _, err := quicvarint.Read(r); err != nil { // ECT1 Count
					return nil, err
				}
				if _, err := quicvarint.Read(r); err != nil { // ECN-CE Count
					return nil, err
				}
			}
		case 0x04: // RESET_STREAM
			if _, err := quicvarint.Read(r); err != nil { // Stream ID
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // Application Error Code
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // Final Size
				return nil, err
			}
		case 0x05: // STOP_SENDING
			if _, err := quicvarint.Read(r); err != nil { // Stream ID
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // Application Error Code
				return nil, err
			}
		case cryptoFrameType: // CRYPTO
			offset, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			dataLen, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			// The Initial keys are derived only from bytes in the packet, so the
			// plaintext, and this length, can be forged. Check the length against
			// what is actually left before allocating; otherwise a bogus value
			// could cause a huge allocation or a makeslice panic.
			if dataLen > uint64(r.Len()) {
				return nil, io.ErrUnexpectedEOF
			}
			data := make([]byte, dataLen)
			if _, err := io.ReadFull(r, data); err != nil {
				return nil, err
			}
			frames = append(frames, CryptoFrame{Offset: int64(offset), Data: data})
		case 0x07: // NEW_TOKEN
			tokenLen, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			if err := skipN(r, tokenLen); err != nil {
				return nil, err
			}
		case 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: // STREAM
			if _, err := quicvarint.Read(r); err != nil { // Stream ID
				return nil, err
			}
			hasOffset := typ&0x04 != 0
			hasLength := typ&0x02 != 0
			if hasOffset {
				if _, err := quicvarint.Read(r); err != nil { // Offset
					return nil, err
				}
			}
			var dataLen uint64
			if hasLength {
				n, err := quicvarint.Read(r)
				if err != nil {
					return nil, err
				}
				dataLen = n
			} else {
				// Without a Length field the data extends to the end of the packet.
				dataLen = uint64(r.Len())
			}
			if err := skipN(r, dataLen); err != nil {
				return nil, err
			}
		case 0x10, 0x12, 0x13, 0x14, 0x16, 0x17, 0x19: // MAX_DATA, MAX_STREAMS_*, DATA_BLOCKED, STREAMS_BLOCKED_*, RETIRE_CONNECTION_ID
			if _, err := quicvarint.Read(r); err != nil {
				return nil, err
			}
		case 0x11, 0x15: // MAX_STREAM_DATA, STREAM_DATA_BLOCKED
			if _, err := quicvarint.Read(r); err != nil {
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil {
				return nil, err
			}
		case 0x18: // NEW_CONNECTION_ID
			if _, err := quicvarint.Read(r); err != nil { // Sequence Number
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // Retire Prior To
				return nil, err
			}
			cidLen, err := r.ReadByte()
			if err != nil {
				return nil, err
			}
			if cidLen > 20 {
				return nil, fmt.Errorf("invalid connection ID length: %d", cidLen)
			}
			if err := skipN(r, uint64(cidLen)); err != nil { // Connection ID
				return nil, err
			}
			if err := skipN(r, 16); err != nil { // Stateless Reset Token
				return nil, err
			}
		case 0x1a, 0x1b: // PATH_CHALLENGE, PATH_RESPONSE
			if err := skipN(r, 8); err != nil {
				return nil, err
			}
		case 0x1c: // CONNECTION_CLOSE (transport)
			if _, err := quicvarint.Read(r); err != nil { // Error Code
				return nil, err
			}
			if _, err := quicvarint.Read(r); err != nil { // Frame Type
				return nil, err
			}
			reasonLen, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			if err := skipN(r, reasonLen); err != nil {
				return nil, err
			}
		case 0x1d: // CONNECTION_CLOSE (application)
			if _, err := quicvarint.Read(r); err != nil { // Error Code
				return nil, err
			}
			reasonLen, err := quicvarint.Read(r)
			if err != nil {
				return nil, err
			}
			if err := skipN(r, reasonLen); err != nil {
				return nil, err
			}
		case 0x30, 0x31: // DATAGRAM
			var dataLen uint64
			if typ&0x01 != 0 {
				n, err := quicvarint.Read(r)
				if err != nil {
					return nil, err
				}
				dataLen = n
			} else {
				// Type 0x30 has no Length field; the data runs to the end of the packet.
				dataLen = uint64(r.Len())
			}
			if err := skipN(r, dataLen); err != nil {
				return nil, err
			}
		default:
			// Unknown/extension frame type: if we already collected CRYPTO
			// frames, return them instead of failing hard.
			if len(frames) > 0 {
				return frames, nil
			}
			return nil, fmt.Errorf("unsupported frame type: %d", typ)
		}
	}
	return frames, nil
}

// skipN advances r by n bytes. It returns io.EOF, without moving r, if fewer
// than n bytes remain.
func skipN(r *bytes.Reader, n uint64) error {
	if n > uint64(r.Len()) {
		return io.EOF
	}
	_, err := r.Seek(int64(n), io.SeekCurrent)
	return err
}

// assembleCryptoFrames assembles multiple crypto frames into a single slice (if possible).
// It returns nil if the frames cannot be assembled, for example when there are
// no frames or when they are not contiguous.
func assembleCryptoFrames(frames []CryptoFrame) []byte { if len(frames) == 0 { return nil } if len(frames) == 1 { return frames[0].Data } // sort the frames by offset sort.Slice(frames, func(i, j int) bool { return frames[i].Offset < frames[j].Offset }) // check if the frames are contiguous for i := 1; i < len(frames); i++ { if frames[i].Offset != frames[i-1].Offset+int64(len(frames[i-1].Data)) { return nil } } // concatenate the frames data := make([]byte, frames[len(frames)-1].Offset+int64(len(frames[len(frames)-1].Data))) for _, frame := range frames { copy(data[frame.Offset:], frame.Data) } return data }