552 lines
14 KiB
Go
552 lines
14 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"container/list"
|
|
"crypto"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/quic-go/quic-go/quicvarint"
|
|
"golang.org/x/crypto/hkdf"
|
|
)
|
|
|
|
var defaultPNMaxGuesses = []int64{
|
|
0, 1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,
|
|
}
|
|
|
|
const (
|
|
initialSecretLabelClientIn = "client in"
|
|
initialSecretLabelServerIn = "server in"
|
|
)
|
|
|
|
const initialProtectorCacheSize = 512
|
|
|
|
var initialProtectorCache = newInitialPacketProtectorCache(initialProtectorCacheSize)
|
|
|
|
// DecryptSuccessHint records which combination of inputs decrypted a packet.
// ReadCryptoFramesWithOptions fills it in when ReadCryptoFramesOptions.SuccessHint
// is non-nil and decryption succeeds.
type DecryptSuccessHint struct {
	// ConnectionID is the connection ID candidate whose derived keys worked.
	// Its backing array is reused across calls when it has enough capacity.
	ConnectionID []byte
	// SecretLabel is the secret label that worked.
	SecretLabel string
	// PacketNumberMax is the packet-number-max guess that worked.
	PacketNumberMax int64
}
|
|
|
|
type ReadCryptoFramesOptions struct {
|
|
AdditionalConnectionIDs [][]byte
|
|
TryServerSecret bool
|
|
PacketNumberMaxGuesses []int64
|
|
PreferredConnectionID []byte
|
|
PreferredSecretLabel string
|
|
PreferredPNMax int64
|
|
HasPreferredPNMax bool
|
|
SuccessHint *DecryptSuccessHint
|
|
}
|
|
|
|
func ReadCryptoPayload(packet []byte) ([]byte, error) {
|
|
frs, err := ReadCryptoFrames(packet)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
data := assembleCryptoFrames(frs)
|
|
if data == nil {
|
|
return nil, errors.New("unable to assemble crypto frames")
|
|
}
|
|
return data, nil
|
|
}
|
|
|
|
// ReadCryptoFrames decrypts a QUIC Initial client packet and returns CRYPTO frames.
|
|
func ReadCryptoFrames(packet []byte) ([]CryptoFrame, error) {
|
|
return ReadCryptoFramesWithOptions(packet, nil)
|
|
}
|
|
|
|
// ReadCryptoFramesWithOptions decrypts a QUIC Initial packet and returns CRYPTO frames.
|
|
func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) ([]CryptoFrame, error) {
|
|
hdr, offset, err := ParseInitialHeader(packet)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Some sanity checks
|
|
if hdr.Version != V1 && hdr.Version != V2 {
|
|
return nil, fmt.Errorf("unsupported version: %x", hdr.Version)
|
|
}
|
|
if offset == 0 || hdr.Length == 0 {
|
|
return nil, errors.New("invalid packet")
|
|
}
|
|
|
|
// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-client-initial
|
|
//
|
|
// "The unprotected header includes the connection ID and a 4-byte packet number encoding for a packet number of 2"
|
|
if int64(len(packet)) < offset+hdr.Length {
|
|
return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), offset+hdr.Length)
|
|
}
|
|
packetView := packet[:offset+hdr.Length]
|
|
|
|
candidateConnIDs := collectConnectionIDCandidates(hdr, opts)
|
|
pnMaxGuesses := collectPacketNumberMaxGuesses(opts)
|
|
labels := collectSecretLabels(opts)
|
|
|
|
var lastErr error
|
|
for _, connID := range candidateConnIDs {
|
|
for _, label := range labels {
|
|
pp, err := getOrCreateInitialPacketProtector(hdr.Version, connID, label)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
for _, pnMax := range pnMaxGuesses {
|
|
packetCopy := append([]byte(nil), packetView...)
|
|
unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload))
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
if opts != nil && opts.SuccessHint != nil {
|
|
opts.SuccessHint.ConnectionID = append(opts.SuccessHint.ConnectionID[:0], connID...)
|
|
opts.SuccessHint.SecretLabel = label
|
|
opts.SuccessHint.PacketNumberMax = pnMax
|
|
}
|
|
return frs, nil
|
|
}
|
|
}
|
|
}
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
return nil, errors.New("unable to decrypt initial packet")
|
|
}
|
|
|
|
// Frame type identifiers that extractCryptoFrames refers to by name.
const (
	paddingFrameType = 0x00 // PADDING
	pingFrameType    = 0x01 // PING
	cryptoFrameType  = 0x06 // CRYPTO
)
|
|
|
|
// CryptoFrame is one CRYPTO frame taken from a decrypted packet payload.
type CryptoFrame struct {
	// Offset is the frame's offset field as read from the wire.
	Offset int64
	// Data holds the frame's payload bytes.
	Data []byte
}
|
|
|
|
func extractCryptoFrames(r *bytes.Reader) ([]CryptoFrame, error) {
|
|
var frames []CryptoFrame
|
|
for r.Len() > 0 {
|
|
typ, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch typ {
|
|
case paddingFrameType, pingFrameType, 0x1e:
|
|
// PADDING, PING, HANDSHAKE_DONE: no payload.
|
|
continue
|
|
case 0x02, 0x03:
|
|
// ACK, ACK_ECN
|
|
if _, err := quicvarint.Read(r); err != nil { // Largest Acknowledged
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // ACK Delay
|
|
return nil, err
|
|
}
|
|
ackRangeCount, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // First ACK Range
|
|
return nil, err
|
|
}
|
|
for i := uint64(0); i < ackRangeCount; i++ {
|
|
if _, err := quicvarint.Read(r); err != nil { // Gap
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // ACK Range Length
|
|
return nil, err
|
|
}
|
|
}
|
|
if typ == 0x03 {
|
|
if _, err := quicvarint.Read(r); err != nil { // ECT0 Count
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // ECT1 Count
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // ECN-CE Count
|
|
return nil, err
|
|
}
|
|
}
|
|
case 0x04:
|
|
// RESET_STREAM
|
|
if _, err := quicvarint.Read(r); err != nil { // Stream ID
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // Application Error Code
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // Final Size
|
|
return nil, err
|
|
}
|
|
case 0x05:
|
|
// STOP_SENDING
|
|
if _, err := quicvarint.Read(r); err != nil { // Stream ID
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // Application Error Code
|
|
return nil, err
|
|
}
|
|
case cryptoFrameType:
|
|
// CRYPTO
|
|
var frame CryptoFrame
|
|
offset, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
frame.Offset = int64(offset)
|
|
dataLen, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
frame.Data = make([]byte, dataLen)
|
|
if _, err := io.ReadFull(r, frame.Data); err != nil {
|
|
return nil, err
|
|
}
|
|
frames = append(frames, frame)
|
|
case 0x07:
|
|
// NEW_TOKEN
|
|
tokenLen, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := skipN(r, tokenLen); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f:
|
|
// STREAM
|
|
if _, err := quicvarint.Read(r); err != nil { // Stream ID
|
|
return nil, err
|
|
}
|
|
hasOffset := typ&0x04 != 0
|
|
hasLength := typ&0x02 != 0
|
|
if hasOffset {
|
|
if _, err := quicvarint.Read(r); err != nil { // Offset
|
|
return nil, err
|
|
}
|
|
}
|
|
var dataLen uint64
|
|
if hasLength {
|
|
n, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dataLen = n
|
|
} else {
|
|
dataLen = uint64(r.Len())
|
|
}
|
|
if err := skipN(r, dataLen); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x10, 0x12, 0x13, 0x14, 0x16, 0x17, 0x19:
|
|
// MAX_DATA, MAX_STREAMS_*, DATA_BLOCKED, STREAMS_BLOCKED_*, RETIRE_CONNECTION_ID
|
|
if _, err := quicvarint.Read(r); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x11, 0x15:
|
|
// MAX_STREAM_DATA, STREAM_DATA_BLOCKED
|
|
if _, err := quicvarint.Read(r); err != nil {
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x18:
|
|
// NEW_CONNECTION_ID
|
|
if _, err := quicvarint.Read(r); err != nil { // Sequence Number
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // Retire Prior To
|
|
return nil, err
|
|
}
|
|
cidLen, err := r.ReadByte()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if cidLen > 20 {
|
|
return nil, fmt.Errorf("invalid connection ID length: %d", cidLen)
|
|
}
|
|
if err := skipN(r, uint64(cidLen)); err != nil { // Connection ID
|
|
return nil, err
|
|
}
|
|
if err := skipN(r, 16); err != nil { // Stateless Reset Token
|
|
return nil, err
|
|
}
|
|
case 0x1a, 0x1b:
|
|
// PATH_CHALLENGE, PATH_RESPONSE
|
|
if err := skipN(r, 8); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x1c:
|
|
// CONNECTION_CLOSE (transport)
|
|
if _, err := quicvarint.Read(r); err != nil { // Error Code
|
|
return nil, err
|
|
}
|
|
if _, err := quicvarint.Read(r); err != nil { // Frame Type
|
|
return nil, err
|
|
}
|
|
reasonLen, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := skipN(r, reasonLen); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x1d:
|
|
// CONNECTION_CLOSE (application)
|
|
if _, err := quicvarint.Read(r); err != nil { // Error Code
|
|
return nil, err
|
|
}
|
|
reasonLen, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := skipN(r, reasonLen); err != nil {
|
|
return nil, err
|
|
}
|
|
case 0x30, 0x31:
|
|
// DATAGRAM
|
|
var dataLen uint64
|
|
if typ&0x01 != 0 {
|
|
n, err := quicvarint.Read(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dataLen = n
|
|
} else {
|
|
dataLen = uint64(r.Len())
|
|
}
|
|
if err := skipN(r, dataLen); err != nil {
|
|
return nil, err
|
|
}
|
|
default:
|
|
// Unknown/extension frame type: if we already collected CRYPTO
|
|
// frames, return them instead of failing hard.
|
|
if len(frames) > 0 {
|
|
return frames, nil
|
|
}
|
|
return nil, fmt.Errorf("unsupported frame type: %d", typ)
|
|
}
|
|
}
|
|
return frames, nil
|
|
}
|
|
|
|
// skipN advances r by n bytes. If fewer than n bytes remain it returns io.EOF
// and leaves the read position where it was.
func skipN(r *bytes.Reader, n uint64) error {
	if remaining := uint64(r.Len()); remaining < n {
		return io.EOF
	}
	if _, err := r.Seek(int64(n), io.SeekCurrent); err != nil {
		return err
	}
	return nil
}
|
|
|
|
func collectConnectionIDCandidates(hdr *Header, opts *ReadCryptoFramesOptions) [][]byte {
|
|
ids := make([][]byte, 0, 2)
|
|
if opts != nil && len(opts.PreferredConnectionID) > 0 {
|
|
ids = append(ids, opts.PreferredConnectionID)
|
|
}
|
|
ids = append(ids, hdr.DestConnectionID)
|
|
if opts != nil {
|
|
ids = append(ids, opts.AdditionalConnectionIDs...)
|
|
}
|
|
return uniqueNonEmptyConnectionIDs(ids)
|
|
}
|
|
|
|
func collectPacketNumberMaxGuesses(opts *ReadCryptoFramesOptions) []int64 {
|
|
guesses := defaultPNMaxGuesses
|
|
if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 {
|
|
guesses = opts.PacketNumberMaxGuesses
|
|
}
|
|
if opts == nil || !opts.HasPreferredPNMax {
|
|
return uniqueInt64PreserveOrder(guesses)
|
|
}
|
|
out := make([]int64, 0, len(guesses)+1)
|
|
out = append(out, opts.PreferredPNMax)
|
|
out = append(out, guesses...)
|
|
return uniqueInt64PreserveOrder(out)
|
|
}
|
|
|
|
func collectSecretLabels(opts *ReadCryptoFramesOptions) []string {
|
|
labels := []string{initialSecretLabelClientIn}
|
|
if opts != nil && opts.TryServerSecret {
|
|
labels = append(labels, initialSecretLabelServerIn)
|
|
}
|
|
if opts == nil || opts.PreferredSecretLabel == "" {
|
|
return labels
|
|
}
|
|
return prependStringIfPresent(labels, opts.PreferredSecretLabel)
|
|
}
|
|
|
|
// prependStringIfPresent moves preferred to the front of base when base
// contains it, returning a new slice with every other element in its original
// order and all occurrences of preferred collapsed into the leading one.
// base itself is returned, unmodified, when preferred is empty or absent.
func prependStringIfPresent(base []string, preferred string) []string {
	if preferred == "" {
		return base
	}

	position := -1
	for i := range base {
		if base[i] == preferred {
			position = i
			break
		}
	}
	if position < 0 {
		return base
	}

	reordered := make([]string, 1, len(base))
	reordered[0] = preferred
	for _, s := range base {
		if s != preferred {
			reordered = append(reordered, s)
		}
	}
	return reordered
}
|
|
|
|
// uniqueNonEmptyConnectionIDs returns copies of the given connection IDs with
// empty entries and repeats removed, keeping the first occurrence of each in
// its original position. The result is never nil.
func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte {
	visited := make(map[string]bool, len(ids))
	unique := make([][]byte, 0, len(ids))
	for i := range ids {
		key := string(ids[i])
		if key == "" || visited[key] {
			continue
		}
		visited[key] = true
		// Converting back from the string yields a fresh copy, so the result
		// does not alias the caller's buffers.
		unique = append(unique, []byte(key))
	}
	return unique
}
|
|
|
|
// uniqueInt64PreserveOrder returns a new slice holding the first occurrence of
// each value, in the order the values first appear. The result is never nil.
func uniqueInt64PreserveOrder(values []int64) []int64 {
	visited := make(map[int64]bool, len(values))
	deduped := make([]int64, 0, len(values))
	for i := range values {
		if v := values[i]; !visited[v] {
			visited[v] = true
			deduped = append(deduped, v)
		}
	}
	return deduped
}
|
|
|
|
// initialPacketProtectorCacheKey identifies one cached protector by the three
// inputs it is derived from. It is comparable so it can serve as a map key.
type initialPacketProtectorCacheKey struct {
	// Version is the QUIC version passed to the key derivation.
	Version uint32
	// ConnID holds the connection ID bytes, stored as a string.
	ConnID string
	// Label is the secret label used for the derivation.
	Label string
}
|
|
|
|
type initialPacketProtectorCacheEntry struct {
|
|
Key initialPacketProtectorCacheKey
|
|
Value *PacketProtector
|
|
}
|
|
|
|
type initialPacketProtectorCache struct {
|
|
mutex sync.Mutex
|
|
capacity int
|
|
ll *list.List
|
|
items map[initialPacketProtectorCacheKey]*list.Element
|
|
}
|
|
|
|
func newInitialPacketProtectorCache(capacity int) *initialPacketProtectorCache {
|
|
if capacity <= 0 {
|
|
capacity = 1
|
|
}
|
|
return &initialPacketProtectorCache{
|
|
capacity: capacity,
|
|
ll: list.New(),
|
|
items: make(map[initialPacketProtectorCacheKey]*list.Element, capacity),
|
|
}
|
|
}
|
|
|
|
func (c *initialPacketProtectorCache) Get(key initialPacketProtectorCacheKey) (*PacketProtector, bool) {
|
|
c.mutex.Lock()
|
|
defer c.mutex.Unlock()
|
|
elem, ok := c.items[key]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
c.ll.MoveToFront(elem)
|
|
return elem.Value.(*initialPacketProtectorCacheEntry).Value, true
|
|
}
|
|
|
|
func (c *initialPacketProtectorCache) Add(key initialPacketProtectorCacheKey, value *PacketProtector) {
|
|
c.mutex.Lock()
|
|
defer c.mutex.Unlock()
|
|
if elem, ok := c.items[key]; ok {
|
|
entry := elem.Value.(*initialPacketProtectorCacheEntry)
|
|
entry.Value = value
|
|
c.ll.MoveToFront(elem)
|
|
return
|
|
}
|
|
elem := c.ll.PushFront(&initialPacketProtectorCacheEntry{
|
|
Key: key,
|
|
Value: value,
|
|
})
|
|
c.items[key] = elem
|
|
for c.ll.Len() > c.capacity {
|
|
oldest := c.ll.Back()
|
|
if oldest == nil {
|
|
break
|
|
}
|
|
c.ll.Remove(oldest)
|
|
delete(c.items, oldest.Value.(*initialPacketProtectorCacheEntry).Key)
|
|
}
|
|
}
|
|
|
|
func getOrCreateInitialPacketProtector(version uint32, connID []byte, label string) (*PacketProtector, error) {
|
|
key := initialPacketProtectorCacheKey{
|
|
Version: version,
|
|
ConnID: string(connID),
|
|
Label: label,
|
|
}
|
|
if cached, ok := initialProtectorCache.Get(key); ok {
|
|
return cached, nil
|
|
}
|
|
initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(version))
|
|
secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size())
|
|
protectionKey, err := NewInitialProtectionKey(secret, version)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("NewInitialProtectionKey: %w", err)
|
|
}
|
|
protector := NewPacketProtector(protectionKey)
|
|
initialProtectorCache.Add(key, protector)
|
|
return protector, nil
|
|
}
|
|
|
|
// assembleCryptoFrames assembles multiple crypto frames into a single slice (if possible).
|
|
// It returns an error if the frames cannot be assembled. This can happen if the frames are not contiguous.
|
|
func assembleCryptoFrames(frames []CryptoFrame) []byte {
|
|
if len(frames) == 0 {
|
|
return nil
|
|
}
|
|
if len(frames) == 1 {
|
|
return frames[0].Data
|
|
}
|
|
// sort the frames by offset
|
|
sort.Slice(frames, func(i, j int) bool { return frames[i].Offset < frames[j].Offset })
|
|
// check if the frames are contiguous
|
|
for i := 1; i < len(frames); i++ {
|
|
if frames[i].Offset != frames[i-1].Offset+int64(len(frames[i-1].Data)) {
|
|
return nil
|
|
}
|
|
}
|
|
// concatenate the frames
|
|
data := make([]byte, frames[len(frames)-1].Offset+int64(len(frames[len(frames)-1].Data)))
|
|
for _, frame := range frames {
|
|
copy(data[frame.Offset:], frame.Data)
|
|
}
|
|
return data
|
|
}
|