analyzer: quic: some optimizations
This commit is contained in:
@@ -2,11 +2,13 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"crypto"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
@@ -16,10 +18,30 @@ var defaultPNMaxGuesses = []int64{
|
||||
0, 1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,
|
||||
}
|
||||
|
||||
const (
|
||||
initialSecretLabelClientIn = "client in"
|
||||
initialSecretLabelServerIn = "server in"
|
||||
)
|
||||
|
||||
const initialProtectorCacheSize = 512
|
||||
|
||||
var initialProtectorCache = newInitialPacketProtectorCache(initialProtectorCacheSize)
|
||||
|
||||
type DecryptSuccessHint struct {
|
||||
ConnectionID []byte
|
||||
SecretLabel string
|
||||
PacketNumberMax int64
|
||||
}
|
||||
|
||||
type ReadCryptoFramesOptions struct {
|
||||
AdditionalConnectionIDs [][]byte
|
||||
TryServerSecret bool
|
||||
PacketNumberMaxGuesses []int64
|
||||
PreferredConnectionID []byte
|
||||
PreferredSecretLabel string
|
||||
PreferredPNMax int64
|
||||
HasPreferredPNMax bool
|
||||
SuccessHint *DecryptSuccessHint
|
||||
}
|
||||
|
||||
func ReadCryptoPayload(packet []byte) ([]byte, error) {
|
||||
@@ -61,33 +83,18 @@ func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) (
|
||||
}
|
||||
packetView := packet[:offset+hdr.Length]
|
||||
|
||||
candidateConnIDs := [][]byte{hdr.DestConnectionID}
|
||||
if opts != nil {
|
||||
candidateConnIDs = append(candidateConnIDs, opts.AdditionalConnectionIDs...)
|
||||
}
|
||||
candidateConnIDs = uniqueNonEmptyConnectionIDs(candidateConnIDs)
|
||||
|
||||
pnMaxGuesses := defaultPNMaxGuesses
|
||||
if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 {
|
||||
pnMaxGuesses = opts.PacketNumberMaxGuesses
|
||||
}
|
||||
|
||||
labels := []string{"client in"}
|
||||
if opts != nil && opts.TryServerSecret {
|
||||
labels = append(labels, "server in")
|
||||
}
|
||||
candidateConnIDs := collectConnectionIDCandidates(hdr, opts)
|
||||
pnMaxGuesses := collectPacketNumberMaxGuesses(opts)
|
||||
labels := collectSecretLabels(opts)
|
||||
|
||||
var lastErr error
|
||||
for _, connID := range candidateConnIDs {
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(hdr.Version))
|
||||
for _, label := range labels {
|
||||
secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size())
|
||||
key, err := NewInitialProtectionKey(secret, hdr.Version)
|
||||
pp, err := getOrCreateInitialPacketProtector(hdr.Version, connID, label)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("NewInitialProtectionKey: %w", err)
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
pp := NewPacketProtector(key)
|
||||
for _, pnMax := range pnMaxGuesses {
|
||||
packetCopy := append([]byte(nil), packetView...)
|
||||
unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax)
|
||||
@@ -100,6 +107,11 @@ func ReadCryptoFramesWithOptions(packet []byte, opts *ReadCryptoFramesOptions) (
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if opts != nil && opts.SuccessHint != nil {
|
||||
opts.SuccessHint.ConnectionID = append(opts.SuccessHint.ConnectionID[:0], connID...)
|
||||
opts.SuccessHint.SecretLabel = label
|
||||
opts.SuccessHint.PacketNumberMax = pnMax
|
||||
}
|
||||
return frs, nil
|
||||
}
|
||||
}
|
||||
@@ -337,6 +349,68 @@ func skipN(r *bytes.Reader, n uint64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func collectConnectionIDCandidates(hdr *Header, opts *ReadCryptoFramesOptions) [][]byte {
|
||||
ids := make([][]byte, 0, 2)
|
||||
if opts != nil && len(opts.PreferredConnectionID) > 0 {
|
||||
ids = append(ids, opts.PreferredConnectionID)
|
||||
}
|
||||
ids = append(ids, hdr.DestConnectionID)
|
||||
if opts != nil {
|
||||
ids = append(ids, opts.AdditionalConnectionIDs...)
|
||||
}
|
||||
return uniqueNonEmptyConnectionIDs(ids)
|
||||
}
|
||||
|
||||
func collectPacketNumberMaxGuesses(opts *ReadCryptoFramesOptions) []int64 {
|
||||
guesses := defaultPNMaxGuesses
|
||||
if opts != nil && len(opts.PacketNumberMaxGuesses) > 0 {
|
||||
guesses = opts.PacketNumberMaxGuesses
|
||||
}
|
||||
if opts == nil || !opts.HasPreferredPNMax {
|
||||
return uniqueInt64PreserveOrder(guesses)
|
||||
}
|
||||
out := make([]int64, 0, len(guesses)+1)
|
||||
out = append(out, opts.PreferredPNMax)
|
||||
out = append(out, guesses...)
|
||||
return uniqueInt64PreserveOrder(out)
|
||||
}
|
||||
|
||||
func collectSecretLabels(opts *ReadCryptoFramesOptions) []string {
|
||||
labels := []string{initialSecretLabelClientIn}
|
||||
if opts != nil && opts.TryServerSecret {
|
||||
labels = append(labels, initialSecretLabelServerIn)
|
||||
}
|
||||
if opts == nil || opts.PreferredSecretLabel == "" {
|
||||
return labels
|
||||
}
|
||||
return prependStringIfPresent(labels, opts.PreferredSecretLabel)
|
||||
}
|
||||
|
||||
func prependStringIfPresent(base []string, preferred string) []string {
|
||||
if preferred == "" {
|
||||
return base
|
||||
}
|
||||
has := false
|
||||
for _, s := range base {
|
||||
if s == preferred {
|
||||
has = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !has {
|
||||
return base
|
||||
}
|
||||
out := make([]string, 0, len(base))
|
||||
out = append(out, preferred)
|
||||
for _, s := range base {
|
||||
if s == preferred {
|
||||
continue
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte {
|
||||
out := make([][]byte, 0, len(ids))
|
||||
seen := make(map[string]struct{}, len(ids))
|
||||
@@ -354,6 +428,103 @@ func uniqueNonEmptyConnectionIDs(ids [][]byte) [][]byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func uniqueInt64PreserveOrder(values []int64) []int64 {
|
||||
out := make([]int64, 0, len(values))
|
||||
seen := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
if _, ok := seen[v]; ok {
|
||||
continue
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
out = append(out, v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type initialPacketProtectorCacheKey struct {
|
||||
Version uint32
|
||||
ConnID string
|
||||
Label string
|
||||
}
|
||||
|
||||
type initialPacketProtectorCacheEntry struct {
|
||||
Key initialPacketProtectorCacheKey
|
||||
Value *PacketProtector
|
||||
}
|
||||
|
||||
type initialPacketProtectorCache struct {
|
||||
mutex sync.Mutex
|
||||
capacity int
|
||||
ll *list.List
|
||||
items map[initialPacketProtectorCacheKey]*list.Element
|
||||
}
|
||||
|
||||
func newInitialPacketProtectorCache(capacity int) *initialPacketProtectorCache {
|
||||
if capacity <= 0 {
|
||||
capacity = 1
|
||||
}
|
||||
return &initialPacketProtectorCache{
|
||||
capacity: capacity,
|
||||
ll: list.New(),
|
||||
items: make(map[initialPacketProtectorCacheKey]*list.Element, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *initialPacketProtectorCache) Get(key initialPacketProtectorCacheKey) (*PacketProtector, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
elem, ok := c.items[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
c.ll.MoveToFront(elem)
|
||||
return elem.Value.(*initialPacketProtectorCacheEntry).Value, true
|
||||
}
|
||||
|
||||
func (c *initialPacketProtectorCache) Add(key initialPacketProtectorCacheKey, value *PacketProtector) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if elem, ok := c.items[key]; ok {
|
||||
entry := elem.Value.(*initialPacketProtectorCacheEntry)
|
||||
entry.Value = value
|
||||
c.ll.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
elem := c.ll.PushFront(&initialPacketProtectorCacheEntry{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
c.items[key] = elem
|
||||
for c.ll.Len() > c.capacity {
|
||||
oldest := c.ll.Back()
|
||||
if oldest == nil {
|
||||
break
|
||||
}
|
||||
c.ll.Remove(oldest)
|
||||
delete(c.items, oldest.Value.(*initialPacketProtectorCacheEntry).Key)
|
||||
}
|
||||
}
|
||||
|
||||
func getOrCreateInitialPacketProtector(version uint32, connID []byte, label string) (*PacketProtector, error) {
|
||||
key := initialPacketProtectorCacheKey{
|
||||
Version: version,
|
||||
ConnID: string(connID),
|
||||
Label: label,
|
||||
}
|
||||
if cached, ok := initialProtectorCache.Get(key); ok {
|
||||
return cached, nil
|
||||
}
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(version))
|
||||
secret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, label, []byte{}, crypto.SHA256.Size())
|
||||
protectionKey, err := NewInitialProtectionKey(secret, version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewInitialProtectionKey: %w", err)
|
||||
}
|
||||
protector := NewPacketProtector(protectionKey)
|
||||
initialProtectorCache.Add(key, protector)
|
||||
return protector, nil
|
||||
}
|
||||
|
||||
// assembleCryptoFrames assembles multiple crypto frames into a single slice (if possible).
|
||||
// It returns an error if the frames cannot be assembled. This can happen if the frames are not contiguous.
|
||||
func assembleCryptoFrames(frames []CryptoFrame) []byte {
|
||||
|
||||
Reference in New Issue
Block a user