diff --git a/analyzer/udp/internal/quic/header.go b/analyzer/udp/internal/quic/header.go index 791f023..4d46e6e 100644 --- a/analyzer/udp/internal/quic/header.go +++ b/analyzer/udp/internal/quic/header.go @@ -9,6 +9,8 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) +var ErrNotInitialPacket = errors.New("not initial packet") + // The Header represents a QUIC header. type Header struct { Type uint8 @@ -36,6 +38,9 @@ func parseLongHeader(b *bytes.Reader) (*Header, error) { if err != nil { return nil, err } + if !isLongHeader(typeByte) { + return nil, ErrNotInitialPacket + } h := &Header{} ver, err := beUint32(b) if err != nil { @@ -66,18 +71,19 @@ func parseLongHeader(b *bytes.Reader) (*Header, error) { if h.Version == V2 { initialPacketType = 0b01 } - if (typeByte >> 4 & 0b11) == initialPacketType { - tokenLen, err := quicvarint.Read(b) - if err != nil { - return nil, err - } - if tokenLen > uint64(b.Len()) { - return nil, io.EOF - } - h.Token = make([]byte, tokenLen) - if _, err := io.ReadFull(b, h.Token); err != nil { - return nil, err - } + if (typeByte>>4)&0b11 != initialPacketType { + return nil, ErrNotInitialPacket + } + tokenLen, err := quicvarint.Read(b) + if err != nil { + return nil, err + } + if tokenLen > uint64(b.Len()) { + return nil, io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return nil, err } pl, err := quicvarint.Read(b) diff --git a/analyzer/udp/internal/quic/payload.go b/analyzer/udp/internal/quic/payload.go index 9a2fa26..75d9da8 100644 --- a/analyzer/udp/internal/quic/payload.go +++ b/analyzer/udp/internal/quic/payload.go @@ -51,15 +51,27 @@ func ReadCryptoFrames(packet []byte) ([]CryptoFrame, error) { if int64(len(packet)) < offset+hdr.Length { return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), offset+hdr.Length) } - unProtectedPayload, err := pp.UnProtect(packet[:offset+hdr.Length], offset, 2) - if err != nil { - return nil, err + packetView := packet[:offset+hdr.Length] + pnMaxGuesses := []int64{0, 1, 2, 3, 4, 8, 16} + var lastErr error + for _, pnMax := range pnMaxGuesses { + packetCopy := append([]byte(nil), packetView...) + unProtectedPayload, err := pp.UnProtect(packetCopy, offset, pnMax) + if err != nil { + lastErr = err + continue + } + frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload)) + if err != nil { + lastErr = err + continue + } + return frs, nil } - frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload)) - if err != nil { - return nil, err + if lastErr != nil { + return nil, lastErr } - return frs, nil + return nil, errors.New("unable to decrypt initial packet") } const ( diff --git a/analyzer/udp/internal/quic/payload_test.go b/analyzer/udp/internal/quic/payload_test.go index d44fb4f..4fe7e21 100644 --- a/analyzer/udp/internal/quic/payload_test.go +++ b/analyzer/udp/internal/quic/payload_test.go @@ -55,3 +55,14 @@ func TestExtractCryptoFrames_UnknownAfterCrypto(t *testing.T) { t.Fatalf("frame0 = %+v, want offset=0 data=abc", frames[0]) } } + +func TestReadCryptoFrames_NonInitialHeader(t *testing.T) { + // Short header packet marker should be rejected as non-initial. + _, err := ReadCryptoFrames([]byte{0x40, 0x01, 0x02, 0x03, 0x04}) + if err == nil { + t.Fatal("ReadCryptoFrames() error = nil, want non-nil") + } + if err.Error() != ErrNotInitialPacket.Error() { + t.Fatalf("ReadCryptoFrames() error = %v, want %v", err, ErrNotInitialPacket) + } +} diff --git a/analyzer/udp/quic.go b/analyzer/udp/quic.go index dae90e5..98e3abf 100644 --- a/analyzer/udp/quic.go +++ b/analyzer/udp/quic.go @@ -1,6 +1,7 @@ package udp import ( + "errors" "sort" "git.difuse.io/Difuse/Mellaris/analyzer" @@ -50,7 +51,14 @@ func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done b const minDataSize = 41 frs, err := quic.ReadCryptoFrames(data) - if err != nil || len(frs) == 0 { + if err != nil { + if errors.Is(err, quic.ErrNotInitialPacket) { + return nil, false + } + s.invalidCount++ + return nil, s.invalidCount >= quicInvalidCountThreshold + } + if len(frs) == 0 { s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } @@ -64,8 +72,8 @@ func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done b } if pl[0] != internal.TypeClientHello { - s.invalidCount++ - return nil, s.invalidCount >= quicInvalidCountThreshold + // Not a ClientHello (e.g. server-direction CRYPTO); ignore. + return nil, false } chLen := int(pl[1])<<16 | int(pl[2])<<8 | int(pl[3])