1
0
mirror of https://github.com/fumiama/terasu-cloudflared.git synced 2026-06-10 05:04:15 +08:00

TUN-9016: update go to 1.24

## Summary

Update several moving parts of cloudflared build system:

* use goboring 1.24.2 in cfsetup
* update linter and fix lint issues
* update packages namely **quic-go and net**
* install script for macos
* update docker files to use go 1.24.1
* remove usage of cloudflare-go
* pin golang linter

Closes TUN-9016
This commit is contained in:
Luis Neto
2025-06-06 09:05:49 +00:00
parent e144eac2af
commit 96ce66bd30
585 changed files with 23572 additions and 21356 deletions

View File

@@ -4,6 +4,7 @@ main
mockgen_tmp.go
*.qtr
*.qlog
*.sqlog
*.txt
race.[0-9]*

View File

@@ -1,45 +1,85 @@
linters-settings:
misspell:
ignore-words:
- ect
depguard:
rules:
quicvarint:
list-mode: strict
files:
- "**/github.com/quic-go/quic-go/quicvarint/*"
- "!$test"
allow:
- $gostd
version: "2"
linters:
disable-all: true
default: none
enable:
- asciicheck
- copyloopvar
- depguard
- exhaustive
- exportloopref
- goimports
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- gosimple
- govet
- ineffassign
- misspell
- prealloc
- staticcheck
- stylecheck
- unconvert
- unparam
- unused
issues:
exclude-files:
- internal/handshake/cipher_suite.go
exclude-rules:
- path: internal/qtls
linters:
- depguard
- path: _test\.go
linters:
- exhaustive
settings:
depguard:
rules:
random:
deny:
- pkg: "math/rand$"
desc: use math/rand/v2
- pkg: "golang.org/x/exp/rand"
desc: use math/rand/v2
quicvarint:
list-mode: strict
files:
- '**/github.com/quic-go/quic-go/quicvarint/*'
- '!$test'
allow:
- $gostd
rsa:
list-mode: original
deny:
- pkg: crypto/rsa
desc: "use crypto/ed25519 instead"
misspell:
ignore-rules:
- ect
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- linters:
- depguard
path: internal/qtls
- linters:
- exhaustive
- prealloc
- unparam
path: _test\.go
- linters:
- staticcheck
path: _test\.go
text: 'SA1029:' # inappropriate key in call to context.WithValue
# WebTransport still relies on the ConnectionTracingID and ConnectionTracingKey.
# See https://github.com/quic-go/quic-go/issues/4405 for more details.
- linters:
- staticcheck
paths:
- http3/
- integrationtests/self/http_test.go
text: 'SA1019:.+quic\.ConnectionTracing(ID|Key)'
paths:
- internal/handshake/cipher_suite.go
- third_party$
- builtin$
- examples$
formatters:
enable:
- gofmt
- gofumpt
- goimports
exclusions:
generated: lax
paths:
- internal/handshake/cipher_suite.go
- third_party$
- builtin$
- examples$

View File

@@ -9,7 +9,8 @@
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
In addition to these base RFCs, it also implements the following RFCs:
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
@@ -33,6 +34,7 @@ Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [reverst](https://github.com/flipt-io/reverst) | Reverse Tunnels in Go over HTTP/3 and QUIC | ![GitHub Repo stars](https://img.shields.io/github/stars/flipt-io/reverst?style=flat-square) |
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |

View File

@@ -7,45 +7,15 @@ import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type client struct {
sendConn sendConn
use0RTT bool
packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config
config *Config
connIDGenerator ConnectionIDGenerator
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.Version
handshakeChan chan struct{}
conn quicConn
tracer *logging.ConnectionTracer
tracingID ConnectionTracingID
logger utils.Logger
}
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// See Dial for more details.
// See [Dial] for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
@@ -63,7 +33,7 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// See DialAddr for more details.
// See [DialAddr] for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
@@ -86,7 +56,7 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See Dial for more details.
// See [Dial] for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
@@ -101,12 +71,12 @@ func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tl
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// This is a convenience function. More advanced use cases should instantiate a [Transport],
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
@@ -132,120 +102,3 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.conn, nil
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.logger,
c.version,
)
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if c.use0RTT {
earlyConnChan = c.conn.earlyConnReady()
}
select {
case <-ctx.Done():
c.conn.destroy(nil)
return context.Cause(ctx)
case err := <-errorChan:
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil
case <-c.conn.HandshakeComplete():
// handshake successfully completed
return nil
}
}

View File

@@ -3,6 +3,7 @@ package quic
import (
"math/bits"
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/utils"
)
@@ -11,7 +12,7 @@ import (
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
counter uint32
counter atomic.Uint32
logger utils.Logger
sendPacket func(net.Addr, packetInfo)
@@ -28,13 +29,13 @@ func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logg
}
func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++
n := c.counter.Add(1)
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
if bits.OnesCount32(c.counter) != 1 {
if bits.OnesCount32(n) != 1 {
return
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", n)
c.sendPacket(p.remoteAddr, p.info)
}

View File

@@ -6,6 +6,8 @@ coverage:
- internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go
- internal/testdata
- logging/connection_tracer_multiplexer.go
- logging/tracer_multiplexer.go
- testutils/
- fuzzing/
- metrics/

View File

@@ -8,7 +8,7 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
// Clone clones a Config
// Clone clones a Config.
func (c *Config) Clone() *Config {
copy := *c
return &copy

View File

@@ -8,41 +8,67 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type connRunnerCallbacks struct {
AddConnectionID func(protocol.ConnectionID)
RemoveConnectionID func(protocol.ConnectionID)
RetireConnectionID func(protocol.ConnectionID)
ReplaceWithClosed func([]protocol.ConnectionID, []byte)
}
type connRunners map[transportID]connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.AddConnectionID(id)
}
}
func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RemoveConnectionID(id)
}
}
func (cr connRunners) RetireConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RetireConnectionID(id)
}
}
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte) {
for _, c := range cr {
c.ReplaceWithClosed(ids, b)
}
}
type connIDGenerator struct {
generator ConnectionIDGenerator
highestSeq uint64
generator ConnectionIDGenerator
highestSeq uint64
connRunners connRunners
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, []byte)
queueControlFrame func(wire.Frame)
statelessResetter *statelessResetter
queueControlFrame func(wire.Frame)
}
func newConnIDGenerator(
tID transportID,
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, []byte),
statelessResetter *statelessResetter,
connRunner connRunnerCallbacks,
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
m := &connIDGenerator{
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
queueControlFrame: queueControlFrame,
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunners: map[transportID]connRunnerCallbacks{tID: connRunner},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
@@ -85,7 +111,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.retireConnectionID(connID)
m.connRunners.RetireConnectionID(connID)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
@@ -100,11 +126,11 @@ func (m *connIDGenerator) issueNewConnID() error {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.addConnectionID(connID)
m.connRunners.AddConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
StatelessResetToken: m.getStatelessResetToken(connID),
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
})
m.highestSeq++
return nil
@@ -112,17 +138,17 @@ func (m *connIDGenerator) issueNewConnID() error {
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(*m.initialClientDestConnID)
m.connRunners.RetireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(*m.initialClientDestConnID)
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
m.connRunners.RemoveConnectionID(connID)
}
}
@@ -134,5 +160,20 @@ func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, connClose)
m.connRunners.ReplaceWithClosed(connIDs, connClose)
}
func (m *connIDGenerator) AddConnRunner(id transportID, r connRunnerCallbacks) {
// The transport might have already been added earlier.
// This happens if the application migrates back to and old path.
if _, ok := m.connRunners[id]; ok {
return
}
m.connRunners[id] = r
if m.initialClientDestConnID != nil {
r.AddConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
r.AddConnectionID(connID)
}
}

View File

@@ -2,11 +2,11 @@ package quic
import (
"fmt"
"slices"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
@@ -17,7 +17,10 @@ type newConnID struct {
}
type connIDManager struct {
queue list.List[newConnID]
queue []newConnID
highestProbingID uint64
pathProbing map[pathID]newConnID // initialized lazily
handshakeComplete bool
activeSequenceNumber uint64
@@ -35,6 +38,8 @@ type connIDManager struct {
addStatelessResetToken func(protocol.StatelessResetToken)
removeStatelessResetToken func(protocol.StatelessResetToken)
queueControlFrame func(wire.Frame)
closed bool
}
func newConnIDManager(
@@ -48,6 +53,7 @@ func newConnIDManager(
addStatelessResetToken: addStatelessResetToken,
removeStatelessResetToken: removeStatelessResetToken,
queueControlFrame: queueControlFrame,
queue: make([]newConnID, 0, protocol.MaxActiveConnectionIDs),
}
}
@@ -59,36 +65,51 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil {
return err
}
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
if len(h.queue) >= protocol.MaxActiveConnectionIDs {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
}
return nil
}
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
if h.activeConnectionID.Len() == 0 {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use",
}
}
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: f.SequenceNumber,
})
return nil
}
if f.RetirePriorTo != 0 && h.pathProbing != nil {
for id, entry := range h.pathProbing {
if entry.SequenceNumber < f.RetirePriorTo {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
h.removeStatelessResetToken(entry.StatelessResetToken)
delete(h.pathProbing, id)
}
}
}
// Retire elements in the queue.
// Doesn't retire the active connection ID.
if f.RetirePriorTo > h.highestRetired {
var next *list.Element[newConnID]
for el := h.queue.Front(); el != nil; el = next {
if el.Value.SequenceNumber >= f.RetirePriorTo {
break
var newQueue []newConnID
for _, entry := range h.queue {
if entry.SequenceNumber >= f.RetirePriorTo {
newQueue = append(newQueue, entry)
} else {
h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber})
}
next = el.Next()
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: el.Value.SequenceNumber,
})
h.queue.Remove(el)
}
h.queue = newQueue
h.highestRetired = f.RetirePriorTo
}
@@ -109,39 +130,43 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
}
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
// insert a new element at the end
if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
h.queue.PushBack(newConnID{
// fast path: add to the end of the queue
if len(h.queue) == 0 || h.queue[len(h.queue)-1].SequenceNumber < seq {
h.queue = append(h.queue, newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
})
return nil
}
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == seq {
if el.Value.ConnectionID != connID {
// slow path: insert in the middle
for i, entry := range h.queue {
if entry.SequenceNumber == seq {
if entry.ConnectionID != connID {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
}
if el.Value.StatelessResetToken != resetToken {
if entry.StatelessResetToken != resetToken {
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
}
break
return nil
}
if el.Value.SequenceNumber > seq {
h.queue.InsertBefore(newConnID{
// insert at the correct position to maintain sorted order
if entry.SequenceNumber > seq {
h.queue = slices.Insert(h.queue, i, newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
}, el)
break
})
return nil
}
}
return nil
return nil // unreachable
}
func (h *connIDManager) updateConnectionID() {
h.assertNotClosed()
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: h.activeSequenceNumber,
})
@@ -150,7 +175,8 @@ func (h *connIDManager) updateConnectionID() {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
front := h.queue.Remove(h.queue.Front())
front := h.queue[0]
h.queue = h.queue[1:]
h.activeSequenceNumber = front.SequenceNumber
h.activeConnectionID = front.ConnectionID
h.activeStatelessResetToken = &front.StatelessResetToken
@@ -160,9 +186,15 @@ func (h *connIDManager) updateConnectionID() {
}
func (h *connIDManager) Close() {
h.closed = true
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
if h.pathProbing != nil {
for _, entry := range h.pathProbing {
h.removeStatelessResetToken(entry.StatelessResetToken)
}
}
}
// is called when the server performs a Retry
@@ -176,6 +208,7 @@ func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
// is called when the server provides a stateless reset token in the transport parameters
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
h.assertNotClosed()
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
@@ -192,17 +225,18 @@ func (h *connIDManager) shouldUpdateConnID() bool {
return false
}
// initiate the first change as early as possible (after handshake completion)
if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
if len(h.queue) > 0 && h.activeSequenceNumber == 0 {
return true
}
// For later changes, only change if
// 1. The queue of connection IDs is filled more than 50%.
// 2. We sent at least PacketsPerConnectionID packets
return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
return 2*len(h.queue) >= protocol.MaxActiveConnectionIDs &&
h.packetsSinceLastChange >= h.packetsPerConnectionID
}
func (h *connIDManager) Get() protocol.ConnectionID {
h.assertNotClosed()
if h.shouldUpdateConnID() {
h.updateConnectionID()
}
@@ -212,3 +246,76 @@ func (h *connIDManager) Get() protocol.ConnectionID {
func (h *connIDManager) SetHandshakeComplete() {
h.handshakeComplete = true
}
// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one).
// Once a connection ID is allocated for a path, it cannot be used for a different path.
// When called with the same pathID, it will return the same connection ID,
// unless the peer requested that this connection ID be retired.
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return protocol.ConnectionID{}, true
}
if h.pathProbing == nil {
h.pathProbing = make(map[pathID]newConnID)
}
entry, ok := h.pathProbing[id]
if ok {
return entry.ConnectionID, true
}
if len(h.queue) == 0 {
return protocol.ConnectionID{}, false
}
front := h.queue[0]
h.queue = h.queue[1:]
h.pathProbing[id] = front
h.highestProbingID = front.SequenceNumber
h.addStatelessResetToken(front.StatelessResetToken)
return front.ConnectionID, true
}
func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return
}
entry, ok := h.pathProbing[pathID]
if !ok {
return
}
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
h.removeStatelessResetToken(entry.StatelessResetToken)
delete(h.pathProbing, pathID)
}
func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool {
if h.activeStatelessResetToken != nil {
if *h.activeStatelessResetToken == token {
return true
}
}
if h.pathProbing != nil {
for _, entry := range h.pathProbing {
if entry.StatelessResetToken == token {
return true
}
}
}
return false
}
// Using the connIDManager after it has been closed can have disastrous effects:
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
// leading to a memory leak of the connection struct.
// See https://github.com/quic-go/quic-go/pull/4852 for more details.
func (h *connIDManager) assertNotClosed() {
if h.closed {
panic("connection ID manager is closed")
}
}

File diff suppressed because it is too large Load Diff

168
vendor/github.com/quic-go/quic-go/connection_logging.go generated vendored Normal file
View File

@@ -0,0 +1,168 @@
package quic
import (
"slices"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ConvertFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func toLoggingFrame(frame wire.Frame) logging.Frame {
switch f := frame.(type) {
case *wire.AckFrame:
// We use a pool for ACK frames.
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
return toLoggingAckFrame(f)
case *wire.CryptoFrame:
return &logging.CryptoFrame{
Offset: f.Offset,
Length: protocol.ByteCount(len(f.Data)),
}
case *wire.StreamFrame:
return &logging.StreamFrame{
StreamID: f.StreamID,
Offset: f.Offset,
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}
}
func toLoggingAckFrame(f *wire.AckFrame) *logging.AckFrame {
ack := &logging.AckFrame{
AckRanges: slices.Clone(f.AckRanges),
DelayTime: f.DelayTime,
ECNCE: f.ECNCE,
ECT0: f.ECT0,
ECT1: f.ECT1,
}
return ack
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
p.header.Log(s.logger)
if p.ack != nil {
wire.LogFrame(s.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
for _, frame := range p.streamFrames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
frames := make([]logging.Frame, 0, len(p.frames))
for _, f := range p.frames {
frames = append(frames, toLoggingFrame(f.Frame))
}
for _, f := range p.streamFrames {
frames = append(frames, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if p.ack != nil {
ack = toLoggingAckFrame(p.ack)
}
s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
}
}
func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame,
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount,
isCoalesced bool,
) {
if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
}
// quic-go logging
if s.logger.Debug() {
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true)
}
for _, f := range frames {
wire.LogFrame(s.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(s.logger, f.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil {
fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames {
fs = append(fs, toLoggingFrame(f.Frame))
}
for _, f := range streamFrames {
fs = append(fs, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if ackFrame != nil {
ack = toLoggingAckFrame(ackFrame)
}
s.tracer.SentShortHeaderPacket(
&logging.ShortHeader{DestConnectionID: destConnID, PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp},
size,
ecn,
ack,
fs,
)
}
}
func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
s.logShortHeaderPacket(
packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length,
false,
)
return
}
if len(packet.longHdrPackets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
} else {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
}
}
for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p, ecn)
}
if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
}
}

View File

@@ -2,27 +2,14 @@ package quic
import (
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoStream interface {
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
HasData() bool
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type cryptoStreamImpl struct {
queue *frameSorter
msgBuf []byte
type cryptoStream struct {
queue frameSorter
highestOffset protocol.ByteCount
finished bool
@@ -31,11 +18,11 @@ type cryptoStreamImpl struct {
writeBuf []byte
}
func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{queue: newFrameSorter()}
func newCryptoStream() *cryptoStream {
return &cryptoStream{queue: *newFrameSorter()}
}
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return &qerr.TransportError{
@@ -56,26 +43,16 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
return nil
}
s.highestOffset = max(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
return err
}
for {
_, data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
return s.queue.Push(f.Data, f.Offset, nil)
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
b := s.msgBuf
s.msgBuf = nil
return b
func (s *cryptoStream) GetCryptoData() []byte {
_, data, _ := s.queue.Pop()
return data
}
func (s *cryptoStreamImpl) Finish() error {
func (s *cryptoStream) Finish() error {
if s.queue.HasMoreData() {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
@@ -87,16 +64,16 @@ func (s *cryptoStreamImpl) Finish() error {
}
// Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
func (s *cryptoStream) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *cryptoStreamImpl) HasData() bool {
func (s *cryptoStream) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
f.Data = s.writeBuf[:n]

View File

@@ -3,32 +3,22 @@ package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
cryptoHandler cryptoDataHandler
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream
initialStream *cryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream
}
func newCryptoStreamManager(
cryptoHandler cryptoDataHandler,
initialStream cryptoStream,
handshakeStream cryptoStream,
oneRTTStream cryptoStream,
initialStream *cryptoStream,
handshakeStream *cryptoStream,
oneRTTStream *cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
cryptoHandler: cryptoHandler,
initialStream: initialStream,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
@@ -36,7 +26,7 @@ func newCryptoStreamManager(
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
var str *cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
@@ -48,18 +38,23 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return nil
}
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
return str.HandleCryptoFrame(frame)
}
func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
var str *cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
str = m.initialStream
case protocol.EncryptionHandshake:
str = m.handshakeStream
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel))
}
return str.GetCryptoData()
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {

View File

@@ -50,8 +50,8 @@ type StreamError struct {
}
func (e *StreamError) Is(target error) bool {
_, ok := target.(*StreamError)
return ok
t, ok := target.(*StreamError)
return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
}
func (e *StreamError) Error() string {
@@ -68,8 +68,8 @@ type DatagramTooLargeError struct {
}
func (e *DatagramTooLargeError) Is(target error) bool {
_, ok := target.(*DatagramTooLargeError)
return ok
t, ok := target.(*DatagramTooLargeError)
return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize
}
func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }

View File

@@ -1,64 +1,54 @@
package quic
import (
"errors"
"slices"
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
)
type framer interface {
HasData() bool
QueueControlFrame(wire.Frame)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
QueuedTooManyControlFrames() bool
}
const (
maxPathResponses = 256
maxControlFrames = 16 << 10
)
type framerI struct {
// This is the largest possible size of a stream-related control frame
// (which is the RESET_STREAM frame).
const maxStreamControlFrameSize = 25
type streamControlFrameGetter interface {
getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool)
}
type framer struct {
mutex sync.Mutex
streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{}
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
activeStreams map[protocol.StreamID]sendStreamI
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
pathResponses []*wire.PathResponseFrame
connFlowController flowcontrol.ConnectionFlowController
queuedTooManyControlFrames bool
}
var _ framer = &framerI{}
func newFramer(streamGetter streamGetter) framer {
return &framerI{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer {
return &framer{
activeStreams: make(map[protocol.StreamID]sendStreamI),
streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
connFlowController: connFlowController,
}
}
func (f *framerI) HasData() bool {
func (f *framer) HasData() bool {
f.mutex.Lock()
hasData := !f.streamQueue.Empty()
f.mutex.Unlock()
@@ -67,10 +57,10 @@ func (f *framerI) HasData() bool {
}
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
return len(f.controlFrames) > 0 || len(f.pathResponses) > 0
return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0
}
func (f *framerI) QueueControlFrame(frame wire.Frame) {
func (f *framer) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
@@ -92,10 +82,80 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrames = append(f.controlFrames, frame)
}
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
func (f *framer) Append(
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
maxLen protocol.ByteCount,
now time.Time,
v protocol.Version,
) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v)
maxLen -= controlFrameLen
var lastFrame ackhandler.StreamFrame
var streamFrameLen protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than 128 bytes are left in the packet
numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize > maxLen {
break
}
sf, blocked := f.getNextStreamFrame(maxLen, v)
if sf.Frame != nil {
streamFrames = append(streamFrames, sf)
maxLen -= sf.Frame.Length(v)
lastFrame = sf
streamFrameLen += sf.Frame.Length(v)
}
// If the stream just became blocked on stream flow control, attempt to pack the
// STREAM_DATA_BLOCKED into the same packet.
if blocked != nil {
l := blocked.Length(v)
// In case it doesn't fit, queue it for the next packet.
if maxLen < l {
f.controlFrames = append(f.controlFrames, blocked)
break
}
frames = append(frames, ackhandler.Frame{Frame: blocked})
maxLen -= l
controlFrameLen += l
}
}
// The only way to become blocked on connection-level flow control is by sending STREAM frames.
if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked {
blocked := &wire.DataBlockedFrame{MaximumData: offset}
l := blocked.Length(v)
// In case it doesn't fit, queue it for the next packet.
if maxLen >= l {
frames = append(frames, ackhandler.Frame{Frame: blocked})
controlFrameLen += l
} else {
f.controlFrames = append(f.controlFrames, blocked)
}
}
f.mutex.Unlock()
f.controlFrameMutex.Unlock()
if lastFrame.Frame != nil {
// account for the smaller size of the last STREAM frame
streamFrameLen -= lastFrame.Frame.Length(v)
lastFrame.Frame.DataLenPresent = false
streamFrameLen += lastFrame.Frame.Length(v)
}
return frames, streamFrames, controlFrameLen + streamFrameLen
}
func (f *framer) appendControlFrames(
frames []ackhandler.Frame,
maxLen protocol.ByteCount,
now time.Time,
v protocol.Version,
) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount
// add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet
if len(f.pathResponses) > 0 {
@@ -108,6 +168,29 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
}
}
// add stream-related control frames
for id, str := range f.streamsWithControlFrames {
start:
remainingLen := maxLen - length
if remainingLen <= maxStreamControlFrameSize {
break
}
fr, ok, hasMore := str.getControlFrame(now)
if !hasMore {
delete(f.streamsWithControlFrames, id)
}
if !ok {
continue
}
frames = append(frames, fr)
length += fr.Frame.Length(v)
if hasMore {
// It is rare that a stream has more than one control frame to queue.
// We don't want to spawn another loop for just to cover that case.
goto start
}
}
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(v)
@@ -118,76 +201,77 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
return frames, length
}
func (f *framerI) QueuedTooManyControlFrames() bool {
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
func (f *framer) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{}
f.activeStreams[id] = str
}
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount
func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) {
f.controlFrameMutex.Lock()
if _, ok := f.streamsWithControlFrames[id]; !ok {
f.streamsWithControlFrames[id] = str
}
f.controlFrameMutex.Unlock()
}
// RemoveActiveStream is called when a stream completes.
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen {
break
}
id := f.streamQueue.PopFront()
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
remainingLen := maxLen - length
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
remainingLen += protocol.ByteCount(quicvarint.Len(uint64(remainingLen)))
frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active
delete(f.activeStreams, id)
}
// The frame can be "nil"
// * if the receiveStream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame
if !ok {
continue
}
frames = append(frames, frame)
length += frame.Frame.Length(v)
}
delete(f.activeStreams, id)
// We don't delete the stream from the streamQueue,
// since we'd have to iterate over the ringbuffer.
// Instead, we check if the stream is still in activeStreams when appending STREAM frames.
f.mutex.Unlock()
if len(frames) > startLen {
l := frames[len(frames)-1].Frame.Length(v)
// account for the smaller size of the last STREAM frame
frames[len(frames)-1].Frame.DataLenPresent = false
length += frames[len(frames)-1].Frame.Length(v) - l
}
return frames, length
}
func (f *framerI) Handle0RTTRejection() error {
func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) {
id := f.streamQueue.PopFront()
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, ok := f.activeStreams[id]
// The stream might have been removed after being enqueued.
if !ok {
return ackhandler.StreamFrame{}, nil
}
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen)))
frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active
delete(f.activeStreams, id)
}
// Note that the frame.Frame can be nil:
// * if the stream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame
return frame, blocked
}
func (f *framer) Handle0RTTRejection() {
f.mutex.Lock()
defer f.mutex.Unlock()
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
f.streamQueue.Clear()
for id := range f.activeStreams {
delete(f.activeStreams, id)
@@ -195,16 +279,13 @@ func (f *framerI) Handle0RTTRejection() error {
var j int
for i, frame := range f.controlFrames {
switch frame.(type) {
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame,
*wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
continue
default:
f.controlFrames[j] = f.controlFrames[i]
j++
}
}
f.controlFrames = f.controlFrames[:j]
f.controlFrameMutex.Unlock()
return nil
f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames))
}

View File

@@ -19,10 +19,6 @@ type StreamID = protocol.StreamID
// A Version is a QUIC version number.
type Version = protocol.Version
// A VersionNumber is a QUIC version number.
// Deprecated: VersionNumber was renamed to Version.
type VersionNumber = Version
const (
// Version1 is RFC 9000
Version1 = protocol.Version1
@@ -48,31 +44,34 @@ type TokenStore interface {
}
// Err0RTTRejected is the returned from:
// * Open{Uni}Stream{Sync}
// * Accept{Uni}Stream
// * Stream.Read and Stream.Write
// - Open{Uni}Stream{Sync}
// - Accept{Uni}Stream
// - Stream.Read and Stream.Write
//
// when the server rejects a 0-RTT connection attempt.
var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
// ConnectionTracingKey can be used to associate a [logging.ConnectionTracer] with a [Connection].
// It is set on the Connection.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
//
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
var ConnectionTracingKey = connTracingCtxKey{}
// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
//
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
type ConnectionTracingID uint64
type connTracingCtxKey struct{}
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
// context returned by tls.Config.ClientHelloInfo.Context.
// context returned by tls.Config.ClientInfo.Context.
var QUICVersionContextKey = handshake.QUICVersionContextKey
// Stream is the interface implemented by QUIC streams
// In addition to the errors listed on the Connection,
// calls to stream functions can return a StreamError if the stream is canceled.
// Stream is the interface implemented by QUIC streams.
// In addition to the errors listed on the [Connection],
// calls to stream functions can return a [StreamError] if the stream is canceled.
type Stream interface {
ReceiveStream
SendStream
@@ -87,12 +86,8 @@ type ReceiveStream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
// Read can be made to time out using SetDeadline and SetReadDeadline.
// If the stream was canceled, the error is a StreamError.
io.Reader
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
@@ -102,7 +97,6 @@ type ReceiveStream interface {
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
}
@@ -111,12 +105,8 @@ type SendStream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
// Write can be made to time out using SetDeadline and SetWriteDeadline.
// If the stream was canceled, the error is a StreamError.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
@@ -146,45 +136,42 @@ type SendStream interface {
// A Connection is a QUIC connection between two peers.
// Calls to the connection (and to streams) can return the following types of errors:
// * ApplicationError: for errors triggered by the application running on top of QUIC
// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
// - [ApplicationError]: for errors triggered by the application running on top of QUIC
// - [TransportError]: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
// - [IdleTimeoutError]: when the peer goes away unexpectedly (this is a [net.Error] timeout error)
// - [HandshakeTimeoutError]: when the cryptographic handshake takes too long (this is a [net.Error] timeout error)
// - [StatelessResetError]: when we receive a stateless reset
// - [VersionNegotiationError]: returned by the client, when there's no version overlap between the peers
type Connection interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptStream(context.Context) (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptUniStream(context.Context) (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, err.Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenStreamSync(context.Context) (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
OpenUniStreamSync(context.Context) (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
@@ -209,6 +196,8 @@ type Connection interface {
SendDatagram(payload []byte) error
// ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221.
ReceiveDatagram(context.Context) ([]byte, error)
AddPath(*Transport) (*Path, error)
}
// An EarlyConnection is a connection that is handshaking.
@@ -238,27 +227,22 @@ type TokenGeneratorKey = handshake.TokenProtectorKey
// as they are allowed by RFC 8999.
type ConnectionID = protocol.ConnectionID
// ConnectionIDFromBytes interprets b as a Connection ID. It panics if b is
// ConnectionIDFromBytes interprets b as a [ConnectionID]. It panics if b is
// longer than 20 bytes.
func ConnectionIDFromBytes(b []byte) ConnectionID {
return protocol.ParseConnectionID(b)
}
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
//
// Connection IDs generated by an implementation should always produce IDs of constant size.
// A ConnectionIDGenerator allows the application to take control over the generation of Connection IDs.
// Connection IDs generated by an implementation must be of constant length.
type ConnectionIDGenerator interface {
// GenerateConnectionID generates a new ConnectionID.
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
// GenerateConnectionID generates a new Connection ID.
// Generated Connection IDs must be unique and observers should not be able to correlate two Connection IDs.
GenerateConnectionID() (ConnectionID, error)
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
// this interface.
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
// in the presence of a connection migration environment.
// ConnectionIDLen returns the length of Connection IDs generated by this implementation.
// Implementations must return constant-length Connection IDs with lengths between 0 and 20 bytes.
// A length of 0 can only be used when an endpoint doesn't need to multiplex connections during migration.
ConnectionIDLen() int
}
@@ -266,7 +250,7 @@ type ConnectionIDGenerator interface {
type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
GetConfigForClient func(info *ClientInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []Version
@@ -327,10 +311,10 @@ type Config struct {
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration
// InitialPacketSize is the initial size of packets sent.
// It is usually not necessary to manually set this value,
// since Path MTU discovery very quickly finds the path's MTU.
// If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake.
// InitialPacketSize is the initial size (and the lower limit) for packets sent.
// Under most circumstances, it is not necessary to manually set this value,
// since path MTU discovery quickly finds the path's MTU.
// If set too high, the path might not support packets of that size, leading to a timeout of the QUIC handshake.
// Values below 1200 are invalid.
InitialPacketSize uint16
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
@@ -346,7 +330,12 @@ type Config struct {
}
// ClientHelloInfo contains information about an incoming connection attempt.
type ClientHelloInfo struct {
//
// Deprecated: Use ClientInfo instead.
type ClientHelloInfo = ClientInfo
// ClientInfo contains information about an incoming connection attempt.
type ClientInfo struct {
// RemoteAddr is the remote address on the Initial packet.
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
RemoteAddr net.Addr
@@ -356,19 +345,19 @@ type ClientHelloInfo struct {
AddrVerified bool
}
// ConnectionState records basic details about a QUIC connection
// ConnectionState records basic details about a QUIC connection.
type ConnectionState struct {
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS tls.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendDatagram and ReceiveDatagram methods on the Connection.
// SupportsDatagrams indicates whether the peer advertised support for QUIC datagrams (RFC 9221).
// When true, datagrams can be sent using the Connection's SendDatagram method.
// This is a unilateral declaration by the peer - receiving datagrams is only possible if
// datagram support was enabled locally via Config.EnableDatagrams.
SupportsDatagrams bool
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version Version
// GSO says if generic segmentation offload is used
// GSO says if generic segmentation offload is used.
GSO bool
}

View File

@@ -10,14 +10,13 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool)
// ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel)
ResetForRetry(rcvTime time.Time) error
SetHandshakeConfirmed()
ReceivedBytes(_ protocol.ByteCount, rcvTime time.Time)
DropPackets(_ protocol.EncryptionLevel, rcvTime time.Time)
ResetForRetry(rcvTime time.Time)
// The SendMode determines if and what kind of packets can be sent.
SendMode(now time.Time) SendMode
@@ -34,12 +33,14 @@ type SentPacketHandler interface {
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
GetLossDetectionTimeout() time.Time
OnLossDetectionTimeout() error
OnLossDetectionTimeout(now time.Time) error
MigratedPath(now time.Time, initialMaxPacketSize protocol.ByteCount)
}
type sentPacketTracker interface {
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ReceivedPacket(protocol.EncryptionLevel)
ReceivedPacket(_ protocol.EncryptionLevel, rcvTime time.Time)
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
@@ -49,5 +50,5 @@ type ReceivedPacketHandler interface {
DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame
}

View File

@@ -22,10 +22,11 @@ type packet struct {
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
isPathProbePacket bool
}
func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket && !p.isPathProbePacket
}
var packetPool = sync.Pool{New: func() any { return &packet{} }}

View File

@@ -38,7 +38,7 @@ func (h *receivedPacketHandler) ReceivedPacket(
rcvTime time.Time,
ackEliciting bool,
) error {
h.sentPackets.ReceivedPacket(encLevel)
h.sentPackets.ReceivedPacket(encLevel, rcvTime)
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
@@ -87,7 +87,7 @@ func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
return h.appDataPackets.GetAlarmTimeout()
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame {
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel {
case protocol.EncryptionInitial:
@@ -101,7 +101,7 @@ func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, o
}
return nil
case protocol.Encryption1RTT:
return h.appDataPackets.GetAckFrame(onlyIfQueued)
return h.appDataPackets.GetAckFrame(now, onlyIfQueued)
default:
// 0-RTT packets can't contain ACK frames
return nil

View File

@@ -1,10 +1,9 @@
package ackhandler
import (
"sync"
"slices"
"github.com/quic-go/quic-go/internal/protocol"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
@@ -14,25 +13,17 @@ type interval struct {
End protocol.PacketNumber
}
var intervalElementPool sync.Pool
func init() {
intervalElementPool = *list.NewPool[interval]()
}
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *list.List[interval]
ranges []interval // maximum length: protocol.MaxNumAckRanges
deletedBelow protocol.PacketNumber
}
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: list.NewWithPool[interval](&intervalElementPool),
}
return &receivedPacketHistory{}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
@@ -41,58 +32,54 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /*
if p < h.deletedBelow {
return false
}
isNew := h.addToRanges(p)
h.maybeDeleteOldRanges()
// Delete old ranges, if we're tracking too many of them.
// This is a DoS defense against a peer that sends us too many gaps.
if len(h.ranges) > protocol.MaxNumAckRanges {
h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
}
return isNew
}
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
if h.ranges.Len() == 0 {
h.ranges.PushBack(interval{Start: p, End: p})
if len(h.ranges) == 0 {
h.ranges = append(h.ranges, interval{Start: p, End: p})
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
for i := len(h.ranges) - 1; i >= 0; i-- {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
if p >= h.ranges[i].Start && p <= h.ranges[i].End {
return false
}
if el.Value.End == p-1 { // extend a range at the end
el.Value.End = p
if h.ranges[i].End == p-1 { // extend a range at the end
h.ranges[i].End = p
return true
}
if el.Value.Start == p+1 { // extend a range at the beginning
el.Value.Start = p
if h.ranges[i].Start == p+1 { // extend a range at the beginning
h.ranges[i].Start = p
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
h.ranges[i-1].End = h.ranges[i].End
h.ranges = slices.Delete(h.ranges, i, i+1)
}
return true
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(interval{Start: p, End: p}, el)
// create a new range after the current one
if p > h.ranges[i].End {
h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
return true
}
}
// create a new range at the beginning
h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
return true
}
// Delete old ranges, if we're tracking more than 500 of them.
// This is a DoS defense against a peer that sends us too many gaps.
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
for h.ranges.Len() > protocol.MaxNumAckRanges {
h.ranges.Remove(h.ranges.Front())
}
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p < h.deletedBelow {
@@ -100,37 +87,39 @@ func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
}
h.deletedBelow = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if len(h.ranges) == 0 {
return
}
if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
return
idx := -1
for i := 0; i < len(h.ranges); i++ {
if h.ranges[i].End < p { // delete a whole range
idx = i
} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
h.ranges[i].Start = p
break
} else { // no ranges affected. Nothing to do
return
break
}
}
if idx >= 0 {
h.ranges = slices.Delete(h.ranges, 0, idx+1)
}
}
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
if h.ranges.Len() > 0 {
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
}
for i := len(h.ranges) - 1; i >= 0; i-- {
ackRanges = append(ackRanges, wire.AckRange{Smallest: h.ranges[i].Start, Largest: h.ranges[i].End})
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
if len(h.ranges) > 0 {
ackRange.Smallest = h.ranges[len(h.ranges)-1].Start
ackRange.Largest = h.ranges[len(h.ranges)-1].End
}
return ackRange
}
@@ -139,11 +128,12 @@ func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber)
if p < h.deletedBelow {
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
if p > el.Value.End {
// Iterating over the slices is faster than using a binary search (using slices.BinarySearchFunc).
for i := len(h.ranges) - 1; i >= 0; i-- {
if p > h.ranges[i].End {
return false
}
if p <= el.Value.End && p >= el.Value.Start {
if p <= h.ranges[i].End && p >= h.ranges[i].Start {
return true
}
}

View File

@@ -196,8 +196,7 @@ func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber,
return false
}
func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
now := time.Now()
func (h *appDataReceivedPacketTracker) GetAckFrame(now time.Time, onlyIfQueued bool) *wire.AckFrame {
if onlyIfQueued && !h.ackQueued {
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
return nil

View File

@@ -27,8 +27,11 @@ const (
maxPTODuration = 60 * time.Second
)
// Path probe packets are declared lost after this time.
const pathProbePacketLossTimeout = time.Second
type packetNumberSpace struct {
history *sentPacketHistory
history sentPacketHistory
pns packetNumberGenerator
lossTime time.Time
@@ -38,21 +41,27 @@ type packetNumberSpace struct {
largestSent protocol.PacketNumber
}
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *packetNumberSpace {
var pns packetNumberGenerator
if skipPNs {
if isAppData {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
} else {
pns = newSequentialPacketNumberGenerator(initialPN)
}
return &packetNumberSpace{
history: newSentPacketHistory(),
history: *newSentPacketHistory(isAppData),
pns: pns,
largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
}
}
type alarmTimer struct {
Time time.Time
TimerType logging.TimerType
EncryptionLevel protocol.EncryptionLevel
}
type sentPacketHandler struct {
initialPackets *packetNumberSpace
handshakePackets *packetNumberSpace
@@ -90,7 +99,7 @@ type sentPacketHandler struct {
numProbesToSend int
// The alarm timeout
alarm time.Time
alarm alarmTimer
enableECN bool
ecnTracker ecnHandler
@@ -155,7 +164,7 @@ func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now time.Time) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
@@ -168,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) (bool, error) {
for p := range pnSpace.history.Packets() {
h.removeFromBytesInFlight(p)
return true, nil
})
}
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
@@ -179,20 +187,22 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
// Dropping the handshake packet number space means that the handshake is confirmed,
// see section 4.9.2 of RFC 9001.
h.handshakeConfirmed = true
h.handshakePackets = nil
case protocol.Encryption0RTT:
// This function is only called when 0-RTT is rejected,
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
for p := range h.appDataPackets.history.Packets() {
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false, nil
break
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
return true, nil
})
}
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
@@ -202,21 +212,21 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
h.ptoCount = 0
h.numProbesToSend = 0
h.ptoMode = SendNone
h.setLossDetectionTimer()
h.setLossDetectionTimer(now)
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t time.Time) {
wasAmplificationLimit := h.isAmplificationLimited()
h.bytesReceived += n
if wasAmplificationLimit && !h.isAmplificationLimited() {
h.setLossDetectionTimer()
h.setLossDetectionTimer(t)
}
}
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t time.Time) {
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
h.peerAddressValidated = true
h.setLossDetectionTimer()
h.setLossDetectionTimer(t)
}
}
@@ -240,11 +250,12 @@ func (h *sentPacketHandler) SentPacket(
ecn protocol.ECN,
size protocol.ByteCount,
isPathMTUProbePacket bool,
isPathProbePacket bool,
) {
h.bytesSent += size
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
if h.logger.Debug() && (pnSpace.history.HasOutstandingPackets() || pnSpace.history.HasOutstandingPathProbes()) {
for p := max(0, pnSpace.largestSent+1); p < pn; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
@@ -253,6 +264,18 @@ func (h *sentPacketHandler) SentPacket(
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isPathProbePacket {
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.Frames = frames
p.isPathProbePacket = true
pnSpace.history.SentPathProbePacket(p)
h.setLossDetectionTimer(t)
return
}
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
@@ -269,7 +292,7 @@ func (h *sentPacketHandler) SentPacket(
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
h.setLossDetectionTimer()
h.setLossDetectionTimer(t)
}
return
}
@@ -289,7 +312,7 @@ func (h *sentPacketHandler) SentPacket(
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer()
h.setLossDetectionTimer(t)
}
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
@@ -322,7 +345,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.peerCompletedAddressValidation = true
h.logger.Debugf("Peer doesn't await address validation any longer.")
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
h.setLossDetectionTimer()
h.setLossDetectionTimer(rcvTime)
}
priorInFlight := h.bytesInFlight
@@ -332,13 +355,13 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
}
// update the RTT, if the largest acked is newly acknowledged
if len(ackedPackets) > 0 {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() && !p.isPathProbePacket {
// don't use the ack delay for Initial and Handshake packets
var ackDelay time.Duration
if encLevel == protocol.Encryption1RTT {
ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay())
}
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
@@ -356,8 +379,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked)
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err
h.detectLostPackets(rcvTime, encLevel)
if encLevel == protocol.Encryption1RTT {
h.detectLostPathProbes(rcvTime)
}
var acked1RTTPacket bool
for _, p := range ackedPackets {
@@ -368,7 +392,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
acked1RTTPacket = true
}
h.removeFromBytesInFlight(p)
putPacket(p)
if !p.isPathProbePacket {
putPacket(p)
}
}
// After this point, we must not use ackedPackets any longer!
// We've already returned the buffers.
@@ -387,7 +413,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer()
h.setLossDetectionTimer(rcvTime)
return acked1RTTPacket, nil
}
@@ -402,14 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
// Ignore packets below the lowest acked
for p := range pnSpace.history.Packets() {
// ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
continue
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
break
}
if ack.HasMissingRanges() {
@@ -421,21 +446,28 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
return true, nil
continue
}
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
return false, &qerr.TransportError{
return nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
}
if p.isPathProbePacket {
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
// the probe packet might already have been declared lost
if probePacket != nil {
h.ackedPackets = append(h.ackedPackets, probePacket)
}
continue
}
h.ackedPackets = append(h.ackedPackets, p)
return true, nil
})
}
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
@@ -466,8 +498,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
}
}
return h.ackedPackets, err
return h.ackedPackets, nil
}
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
@@ -498,41 +529,44 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration
}
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel) {
// We only send application data probe packets once the handshake is confirmed,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
if h.peerCompletedAddressValidation {
return
}
t := time.Now().Add(h.getScaledPTO(false))
t := now.Add(h.getScaledPTO(false))
if h.initialPackets != nil {
return t, protocol.EncryptionInitial, true
return t, protocol.EncryptionInitial
}
return t, protocol.EncryptionHandshake, true
return t, protocol.EncryptionHandshake
}
if h.initialPackets != nil {
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() &&
!h.initialPackets.lastAckElicitingPacketTime.IsZero() {
encLevel = protocol.EncryptionInitial
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
pto = t.Add(h.getScaledPTO(false))
}
}
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() &&
!h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.EncryptionHandshake
}
}
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
if h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets() &&
!h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.Encryption1RTT
}
}
return pto, encLevel, true
return pto, encLevel
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
@@ -545,65 +579,91 @@ func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
return false
}
func (h *sentPacketHandler) hasOutstandingPackets() bool {
return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
}
func (h *sentPacketHandler) setLossDetectionTimer() {
func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) {
oldAlarm := h.alarm // only needed in case tracing is enabled
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = lossTime
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
newAlarm := h.lossDetectionTime(now)
h.alarm = newAlarm
hasAlarm := !newAlarm.Time.IsZero()
if !hasAlarm && !oldAlarm.Time.IsZero() {
h.logger.Debugf("Canceling loss detection timer.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
return
}
// Cancel the alarm if amplification limited.
if h.isAmplificationLimited() {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// Cancel the alarm if no packets are outstanding
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// PTO alarm
ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
if !oldAlarm.IsZero() {
h.alarm = time.Time{}
h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
h.alarm = ptoTime
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
if hasAlarm && h.tracer != nil && h.tracer.SetLossTimer != nil && newAlarm != oldAlarm {
h.tracer.SetLossTimer(newAlarm.TimerType, newAlarm.EncryptionLevel, newAlarm.Time)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
// cancel the alarm if no packets are outstanding
if h.peerCompletedAddressValidation && !h.hasOutstandingCryptoPackets() &&
!h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.history.HasOutstandingPathProbes() {
return alarmTimer{}
}
// cancel the alarm if amplification limited
if h.isAmplificationLimited() {
return alarmTimer{}
}
var pathProbeLossTime time.Time
if h.appDataPackets.history.HasOutstandingPathProbes() {
if p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil {
pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout)
}
}
// early retransmit timer or time loss detection
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() && (pathProbeLossTime.IsZero() || lossTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: lossTime,
TimerType: logging.TimerTypeACK,
EncryptionLevel: encLevel,
}
}
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if !ptoTime.IsZero() && (pathProbeLossTime.IsZero() || ptoTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: ptoTime,
TimerType: logging.TimerTypePTO,
EncryptionLevel: encLevel,
}
}
if !pathProbeLossTime.IsZero() {
return alarmTimer{
Time: pathProbeLossTime,
TimerType: logging.TimerTypePathProbe,
EncryptionLevel: protocol.Encryption1RTT,
}
}
return alarmTimer{}
}
func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
if !h.appDataPackets.history.HasOutstandingPathProbes() {
return
}
lossTime := now.Add(-pathProbePacketLossTimeout)
// RemovePathProbe cannot be called while iterating.
var lostPathProbes []*packet
for p := range h.appDataPackets.history.PathProbes() {
if !p.SendTime.After(lossTime) {
lostPathProbes = append(lostPathProbes, p)
}
}
for _, p := range lostPathProbes {
for _, f := range p.Frames {
f.Handler.OnLost(f.Frame)
}
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
@@ -617,15 +677,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *packet) (bool, error) {
for p := range pnSpace.history.Packets() {
if p.PacketNumber > pnSpace.largestAcked {
return false, nil
break
}
isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
var packetLost bool
if p.SendTime.Before(lostSendTime) {
if !p.SendTime.After(lostSendTime) {
packetLost = true
if !p.skippedPacket {
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
}
@@ -635,7 +696,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true
if !p.skippedPacket {
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
}
@@ -653,7 +714,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
if packetLost {
pnSpace.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket {
if isRegularPacket {
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
@@ -665,12 +726,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
}
}
return true, nil
})
}
}
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
defer h.setLossDetectionTimer()
func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
defer h.setLossDetectionTimer(now)
if h.handshakeConfirmed {
h.detectLostPathProbes(now)
}
earliestLossTime, encLevel := h.getLossTimeAndSpace()
if !earliestLossTime.IsZero() {
if h.logger.Debug() {
@@ -680,13 +745,14 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
}
// Early retransmit or time loss detection
return h.detectLostPackets(time.Now(), encLevel)
h.detectLostPackets(now, encLevel)
return nil
}
// PTO
// When all outstanding are acknowledged, the alarm is canceled in
// setLossDetectionTimer. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// When all outstanding are acknowledged, the alarm is canceled in setLossDetectionTimer.
// However, there's no way to reset the timer in the connection.
// When OnLossDetectionTimeout is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
h.ptoCount++
@@ -701,11 +767,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
return nil
}
_, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if ptoTime.IsZero() {
return nil
}
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
ps := h.getPacketNumberSpace(encLevel)
if !ps.history.HasOutstandingPackets() && !ps.history.HasOutstandingPathProbes() && !h.peerCompletedAddressValidation {
return nil
}
h.ptoCount++
@@ -739,7 +806,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
}
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm
return h.alarm.Time
}
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
@@ -756,7 +823,7 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
@@ -864,33 +931,30 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
p.Frames = nil
}
func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
func (h *sentPacketHandler) ResetForRetry(now time.Time) {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
for p := range h.initialPackets.history.Packets() {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost || p.skippedPacket {
return true, nil
}
h.queueFramesForRetransmission(p)
return true, nil
})
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true, nil
})
}
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
for p := range h.appDataPackets.history.Packets() {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
}
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
if h.ptoCount == 0 {
// Don't set the RTT to a value lower than 5ms here.
h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
@@ -901,28 +965,36 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
oldAlarm := h.alarm
h.alarm = time.Time{}
h.alarm = alarmTimer{}
if h.tracer != nil {
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(0)
}
if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil {
if !oldAlarm.Time.IsZero() && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
h.ptoCount = 0
return nil
}
func (h *sentPacketHandler) SetHandshakeConfirmed() {
if h.initialPackets != nil {
panic("didn't drop initial correctly")
func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
h.rttStats.ResetForPathMigration()
for p := range h.appDataPackets.history.Packets() {
h.appDataPackets.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket && !p.isPathProbePacket {
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
}
}
if h.handshakePackets != nil {
panic("didn't drop handshake correctly")
for p := range h.appDataPackets.history.PathProbes() {
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary.
h.setLossDetectionTimer()
h.congestion = congestion.NewCubicSender(
congestion.DefaultClock{},
h.rttStats,
initialMaxDatagramSize,
true, // use Reno
h.tracer,
)
h.setLossDetectionTimer(now)
}

View File

@@ -2,23 +2,30 @@ package ackhandler
import (
"fmt"
"iter"
"github.com/quic-go/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packets []*packet
packets []*packet
pathProbePackets []*packet
numOutstanding int
highestPacketNumber protocol.PacketNumber
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packets: make([]*packet, 0, 32),
func newSentPacketHistory(isAppData bool) *sentPacketHistory {
h := &sentPacketHistory{
highestPacketNumber: protocol.InvalidPacketNumber,
}
if isAppData {
h.packets = make([]*packet, 0, 32)
} else {
h.packets = make([]*packet, 0, 6)
}
return h
}
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
@@ -27,11 +34,11 @@ func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNum
panic("non-sequential packet number use")
}
}
h.highestPacketNumber = pn
}
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
@@ -40,7 +47,6 @@ func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
@@ -48,28 +54,42 @@ func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber)
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.highestPacketNumber = p.PacketNumber
h.packets = append(h.packets, p)
if p.outstanding() {
h.numOutstanding++
}
}
// Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
for _, p := range h.packets {
if p == nil {
continue
}
cont, err := cb(p)
if err != nil {
return err
}
if !cont {
return nil
func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.packets = append(h.packets, &packet{
PacketNumber: p.PacketNumber,
isPathProbePacket: true,
})
h.pathProbePackets = append(h.pathProbePackets, p)
}
func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.packets {
if p == nil {
continue
}
if !yield(p) {
return
}
}
}
}
func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.pathProbePackets {
if !yield(p) {
return
}
}
}
return nil
}
// FirstOutstanding returns the first outstanding packet.
@@ -85,6 +105,14 @@ func (h *sentPacketHistory) FirstOutstanding() *packet {
return nil
}
// FirstOutstandingPathProbe returns the first outstanding path probe packet
func (h *sentPacketHistory) FirstOutstandingPathProbe() *packet {
if len(h.pathProbePackets) == 0 {
return nil
}
return h.pathProbePackets[0]
}
func (h *sentPacketHistory) Len() int {
return len(h.packets)
}
@@ -120,6 +148,27 @@ func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
return nil
}
// RemovePathProbe removes a path probe packet.
// It scales O(N), but that's ok, since we don't expect to send many path probe packets.
// It is not valid to call this function in IteratePathProbes.
func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet {
var packetToDelete *packet
idx := -1
for i, p := range h.pathProbePackets {
if p.PacketNumber == pn {
packetToDelete = p
idx = i
break
}
}
if idx != -1 {
// don't use slices.Delete, because it zeros the deleted element
copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:])
h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1]
}
return packetToDelete
}
// getIndex gets the index of packet p in the packets slice.
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
@@ -140,6 +189,10 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
func (h *sentPacketHistory) HasOutstandingPathProbes() bool {
return len(h.pathProbePackets) > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {

View File

@@ -36,7 +36,7 @@ type baseFlowController struct {
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
@@ -56,7 +56,7 @@ func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (update
return false
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
func (c *baseFlowController) SendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
@@ -66,11 +66,6 @@ func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
// needs to be called with locked mutex
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.startNewAutoTuningEpoch(time.Now())
}
c.bytesRead += n
}
@@ -82,19 +77,19 @@ func (c *baseFlowController) hasWindowUpdate() bool {
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
func (c *baseFlowController) getWindowUpdate(now time.Time) protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize()
c.maybeAdjustWindowSize(now)
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
func (c *baseFlowController) maybeAdjustWindowSize(now time.Time) {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
@@ -106,7 +101,6 @@ func (c *baseFlowController) maybeAdjustWindowSize() {
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
now := time.Now()
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize)

View File

@@ -12,8 +12,6 @@ import (
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
@@ -23,11 +21,10 @@ var _ ConnectionFlowController = &connectionFlowController{}
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
allowWindowIncrease func(size protocol.ByteCount) bool,
rttStats *utils.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
) *connectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
@@ -37,20 +34,20 @@ func NewConnectionFlowController(
allowWindowIncrease: allowWindowIncrease,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now time.Time) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// If this is the first frame received on this connection, start flow-control auto-tuning.
if c.highestReceived == 0 {
c.startNewAutoTuningEpoch(now)
}
c.highestReceived += increment
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
@@ -60,44 +57,47 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
return nil
}
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
defer c.mutex.Unlock()
c.addBytesRead(n)
return c.hasWindowUpdate()
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
func (c *connectionFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
offset := c.getWindowUpdate(now)
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now time.Time) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
newSize := min(inc, c.maxReceiveWindowSize)
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
c.receiveWindowSize = newSize
}
c.startNewAutoTuningEpoch(time.Now())
defer c.mutex.Unlock()
if inc <= c.receiveWindowSize {
return
}
c.mutex.Unlock()
newSize := min(inc, c.maxReceiveWindowSize)
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
c.receiveWindowSize = newSize
if c.logger.Debug() {
c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize)
}
}
c.startNewAutoTuningEpoch(now)
}
// Reset rests the flow controller. This happens when 0-RTT is rejected.
// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
// All stream data is invalidated, it's as if we had never opened a stream and never sent any data.
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
func (c *connectionFlowController) Reset() error {
c.mutex.Lock()
@@ -108,5 +108,6 @@ func (c *connectionFlowController) Reset() error {
}
c.bytesSent = 0
c.lastBlockedAt = 0
c.sendWindow = 0
return nil
}

View File

@@ -1,6 +1,10 @@
package flowcontrol
import "github.com/quic-go/quic-go/internal/protocol"
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
type flowController interface {
// for sending
@@ -8,34 +12,36 @@ type flowController interface {
UpdateSendWindow(protocol.ByteCount) (updated bool)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
GetWindowUpdate(time.Time) protocol.ByteCount // returns 0 if no update is necessary
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
AddBytesRead(protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool)
// UpdateHighestReceived is called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error
// Abandon is called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead.
Abandon()
IsNewlyBlocked() bool
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
AddBytesRead(protocol.ByteCount) (hasWindowUpdate bool)
Reset() error
IsNewlyBlocked() (bool, protocol.ByteCount)
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowSize(protocol.ByteCount)
EnsureMinimumWindowSize(protocol.ByteCount, time.Time)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
IncrementHighestReceived(protocol.ByteCount, time.Time) error
}

View File

@@ -2,6 +2,7 @@ package flowcontrol
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
@@ -13,8 +14,6 @@ type streamFlowController struct {
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
receivedFinalOffset bool
@@ -29,14 +28,12 @@ func NewStreamFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@@ -49,7 +46,7 @@ func NewStreamFlowController(
}
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error {
// If the final offset for this stream is already known, check for consistency.
if c.receivedFinalOffset {
// If we receive another final offset, check that it's the same.
@@ -74,9 +71,8 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
if offset == c.highestReceived {
return nil
}
// A higher offset was received before.
// This can happen due to reordering.
if offset <= c.highestReceived {
// A higher offset was received before. This can happen due to reordering.
if offset < c.highestReceived {
if final {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
@@ -86,26 +82,29 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
return nil
}
// If this is the first frame received for this stream, start flow-control auto-tuning.
if c.highestReceived == 0 {
c.startNewAutoTuningEpoch(now)
}
increment := offset - c.highestReceived
c.highestReceived = offset
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
}
}
return c.connection.IncrementHighestReceived(increment)
return c.connection.IncrementHighestReceived(increment, now)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
c.addBytesRead(n)
hasStreamWindowUpdate = c.shouldQueueWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
c.connection.AddBytesRead(n)
hasConnWindowUpdate = c.connection.AddBytesRead(n)
return
}
func (c *streamFlowController) Abandon() {
@@ -124,27 +123,32 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) IsNewlyBlocked() bool {
blocked, _ := c.baseFlowController.IsNewlyBlocked()
return blocked
}
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
return !c.receivedFinalOffset && c.hasWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
func (c *streamFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
if c.receivedFinalOffset {
return 0
}
// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
defer c.mutex.Unlock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
offset := c.getWindowUpdate(now)
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize)
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now)
}
c.mutex.Unlock()
return offset
}

View File

@@ -12,9 +12,7 @@ import (
// These cipher suite implementations are copied from the standard library crypto/tls package.
const (
aeadNonceLength = 12
)
const aeadNonceLength = 12
type cipherSuite struct {
ID uint16

View File

@@ -12,7 +12,6 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
@@ -89,12 +88,13 @@ func NewCryptoSetupClient(
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.allow0RTT = enable0RTT
cs.conn = tls.QUICClient(quicConf)
cs.conn = tls.QUICClient(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs
@@ -123,9 +123,13 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
cs.conn = tls.QUICServer(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
return cs
}
@@ -178,11 +182,10 @@ func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
if err := h.handleEvent(ev); err != nil {
return wrapError(err)
}
if done {
if ev.Kind == tls.QUICNoEvent {
break
}
}
@@ -213,47 +216,78 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
}
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
if err := h.handleEvent(ev); err != nil {
return err
}
if done {
if ev.Kind == tls.QUICNoEvent {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) {
switch ev.Kind {
case tls.QUICNoEvent:
return true, nil
return nil
case tls.QUICSetReadSecret:
h.setReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
return nil
case tls.QUICSetWriteSecret:
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
return nil
case tls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
return h.handleTransportParameters(ev.Data)
case tls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return false, nil
return nil
case tls.QUICRejectedEarlyData:
h.rejected0RTT()
return false, nil
return nil
case tls.QUICWriteData:
h.writeRecord(ev.Level, ev.Data)
return false, nil
return nil
case tls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
return nil
case tls.QUICStoreSession:
if h.perspective == protocol.PerspectiveServer {
panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server")
}
ev.SessionState.Extra = append(
ev.SessionState.Extra,
addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)),
)
return h.conn.StoreSession(ev.SessionState)
case tls.QUICResumeSession:
var allowEarlyData bool
switch h.perspective {
case protocol.PerspectiveClient:
// for clients, this event occurs when a session ticket is selected
allowEarlyData = h.handleDataFromSessionState(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
case protocol.PerspectiveServer:
// for servers, this event occurs when receiving the client's session ticket
allowEarlyData = h.handleSessionTicket(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
}
if ev.SessionState.EarlyData {
ev.SessionState.EarlyData = allowEarlyData
}
return nil
default:
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
// Unknown events should be ignored.
// crypto/tls will ensure that this is safe to do.
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
return nil
}
}
@@ -344,7 +378,10 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())},
}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
@@ -370,9 +407,9 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT.
// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT:
// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT.
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
if err := t.Unmarshal(data, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
@@ -440,7 +477,7 @@ func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
}
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
@@ -491,7 +528,7 @@ func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
panic("unexpected write encryption level")
}
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective)
}
}
@@ -621,8 +658,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState {
}
func wrapError(err error) error {
// alert 80 is an internal error
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) {
return qerr.NewLocalCryptoError(uint8(alertErr), err)
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}

View File

@@ -1,4 +1,4 @@
package qtls
package handshake
import (
"net"

View File

@@ -8,8 +8,6 @@ import (
)
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))

View File

@@ -83,6 +83,29 @@ const (
EventHandshakeComplete
)
func (k EventKind) String() string {
switch k {
case EventNoEvent:
return "EventNoEvent"
case EventWriteInitialData:
return "EventWriteInitialData"
case EventWriteHandshakeData:
return "EventWriteHandshakeData"
case EventReceivedReadKeys:
return "EventReceivedReadKeys"
case EventDiscard0RTTKeys:
return "EventDiscard0RTTKeys"
case EventReceivedTransportParameters:
return "EventReceivedTransportParameters"
case EventRestoredTransportParameters:
return "EventRestoredTransportParameters"
case EventHandshakeComplete:
return "EventHandshakeComplete"
default:
return "Unknown EventKind"
}
}
// Event is a handshake event.
type Event struct {
Kind EventKind

View File

@@ -10,16 +10,13 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
)
// Instead of using an init function, the AEADs are created lazily.
// For more details see https://github.com/quic-go/quic-go/issues/4894.
var (
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
)
func init() {
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
}
func initAEAD(key [16]byte) cipher.AEAD {
aes, err := aes.NewCipher(key[:])
if err != nil {
@@ -52,8 +49,14 @@ func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, ve
var tag [16]byte
var sealed []byte
if version == protocol.Version2 {
if retryAEADv2 == nil {
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
}
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
} else {
if retryAEADv1 == nil {
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
}
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
}
if len(sealed) != 16 {

View File

@@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
@@ -52,3 +53,20 @@ func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
t.RTT = time.Duration(rtt) * time.Microsecond
return nil
}
const extraPrefix = "quic-go1"
func addSessionStateExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findSessionStateExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View File

@@ -0,0 +1,39 @@
package handshake
import (
"crypto/tls"
"net"
)
func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// we're returning a tls.Config here, so we need to apply this recursively
c = setupConfigForServer(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}

View File

@@ -46,7 +46,7 @@ type TokenGenerator struct {
// NewTokenGenerator initializes a new TokenGenerator
func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
return &TokenGenerator{tokenProtector: newTokenProtector(key)}
return &TokenGenerator{tokenProtector: *newTokenProtector(key)}
}
// NewRetryToken generates a new token for a Retry for a given source address

View File

@@ -14,28 +14,20 @@ import (
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
type TokenProtectorKey [32]byte
// TokenProtector is used to create and verify a token
type tokenProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const tokenNonceSize = 32
// tokenProtector is used to create and verify a token
type tokenProtectorImpl struct {
type tokenProtector struct {
key TokenProtectorKey
}
// newTokenProtector creates a source for source address tokens
func newTokenProtector(key TokenProtectorKey) tokenProtector {
return &tokenProtectorImpl{key: key}
func newTokenProtector(key TokenProtectorKey) *tokenProtector {
return &tokenProtector{key: key}
}
// NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
var nonce [tokenNonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, err
@@ -48,7 +40,7 @@ func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
}
// DecodeToken decodes a token.
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
if len(p) < tokenNonceSize {
return nil, fmt.Errorf("token too short: %d", len(p))
}
@@ -60,7 +52,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {

View File

@@ -1,50 +0,0 @@
package logutils
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ConvertFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func ConvertFrame(frame wire.Frame) logging.Frame {
switch f := frame.(type) {
case *wire.AckFrame:
// We use a pool for ACK frames.
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
return ConvertAckFrame(f)
case *wire.CryptoFrame:
return &logging.CryptoFrame{
Offset: f.Offset,
Length: protocol.ByteCount(len(f.Data)),
}
case *wire.StreamFrame:
return &logging.StreamFrame{
StreamID: f.StreamID,
Offset: f.Offset,
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}
}
func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame {
ranges := make([]wire.AckRange, 0, len(f.AckRanges))
ranges = append(ranges, f.AckRanges...)
ack := &logging.AckFrame{
AckRanges: ranges,
DelayTime: f.DelayTime,
ECNCE: f.ECNCE,
ECT0: f.ECT0,
ECT1: f.ECT1,
}
return ack
}

View File

@@ -1,5 +1,10 @@
package protocol
import (
"crypto/tls"
"fmt"
)
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel uint8
@@ -28,3 +33,33 @@ func (e EncryptionLevel) String() string {
}
return "unknown"
}
func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel {
switch e {
case EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}

View File

@@ -21,58 +21,36 @@ const (
PacketNumberLen4 PacketNumberLen = 4
)
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func DecodePacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
) PacketNumber {
var epochDelta PacketNumber
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 8
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 16
case PacketNumberLen3:
epochDelta = PacketNumber(1) << 24
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 32
// DecodePacketNumber calculates the packet number based its length and the last seen packet number
// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3.
func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber {
expected := largest + 1
win := PacketNumber(1 << (length * 8))
hwin := win / 2
mask := win - 1
candidate := (expected & ^mask) | truncated
if candidate <= expected-hwin && candidate < 1<<62-win {
return candidate + win
}
epoch := lastPacketNumber & ^(epochDelta - 1)
var prevEpochBegin PacketNumber
if epoch > epochDelta {
prevEpochBegin = epoch - epochDelta
if candidate > expected+hwin && candidate >= win {
return candidate - win
}
nextEpochBegin := epoch + epochDelta
return closestTo(
lastPacketNumber+1,
epoch+wirePacketNumber,
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
)
return candidate
}
func closestTo(target, a, b PacketNumber) PacketNumber {
if delta(target, a) < delta(target, b) {
return a
}
return b
}
func delta(a, b PacketNumber) PacketNumber {
if a < b {
return b - a
}
return a - b
}
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// PacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (16 - 1)) {
func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen {
var numUnacked PacketNumber
if largestAcked == InvalidPacketNumber {
numUnacked = pn + 1
} else {
numUnacked = pn - largestAcked
}
if numUnacked < 1<<(16-1) {
return PacketNumberLen2
}
if diff < (1 << (24 - 1)) {
if numUnacked < 1<<(24-1) {
return PacketNumberLen3
}
return PacketNumberLen4

View File

@@ -102,10 +102,6 @@ const DefaultIdleTimeout = 30 * time.Second
// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
const DefaultHandshakeIdleTimeout = 5 * time.Second
// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
// It should be shorter than the time that NATs clear their mapping.
const MaxKeepAliveInterval = 20 * time.Second
// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE.
// after this time all information about the old connection will be deleted
const RetiredConnectionIDDeleteTimeout = 5 * time.Second

View File

@@ -123,6 +123,10 @@ const MinUnknownVersionPacketSize = MinInitialPacketSize
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
// MinReceivedStatelessResetSize is the minimum size of a received stateless reset,
// as specified in section 10.3 of RFC 9000.
const MinReceivedStatelessResetSize = 5 + 16
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@@ -1,13 +1,12 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
mrand "math/rand/v2"
"sync"
"time"
"golang.org/x/exp/rand"
)
// Version is a version number as int
@@ -90,13 +89,22 @@ func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
var (
versionNegotiationMx sync.Mutex
versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano())))
versionNegotiationRand mrand.Rand
)
func init() {
var seed [16]byte
rand.Read(seed[:])
versionNegotiationRand = *mrand.New(mrand.NewPCG(
binary.BigEndian.Uint64(seed[:8]),
binary.BigEndian.Uint64(seed[8:]),
))
}
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() Version {
var b [4]byte
_, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything
binary.BigEndian.PutUint32(b[:], versionNegotiationRand.Uint32())
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
}
@@ -105,7 +113,7 @@ func generateReservedVersion() Version {
func GetGreasedVersions(supported []Version) []Version {
versionNegotiationMx.Lock()
defer versionNegotiationMx.Unlock()
randPos := rand.Intn(len(supported) + 1)
randPos := versionNegotiationRand.IntN(len(supported) + 1)
greased := make([]Version, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()

View File

@@ -48,21 +48,16 @@ func (e *TransportError) Error() string {
return str + ": " + msg
}
func (e *TransportError) Is(target error) bool {
return target == net.ErrClosed
}
func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} }
func (e *TransportError) Unwrap() error {
return e.error
func (e *TransportError) Is(target error) bool {
t, ok := target.(*TransportError)
return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote
}
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64
func (e *ApplicationError) Is(target error) bool {
return target == net.ErrClosed
}
// A StreamErrorCode is an error code used to cancel streams.
type StreamErrorCode uint64
@@ -81,23 +76,30 @@ func (e *ApplicationError) Error() string {
return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
}
func (e *ApplicationError) Unwrap() error { return net.ErrClosed }
func (e *ApplicationError) Is(target error) bool {
t, ok := target.(*ApplicationError)
return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
}
type IdleTimeoutError struct{}
var _ error = &IdleTimeoutError{}
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed }
type HandshakeTimeoutError struct{}
var _ error = &HandshakeTimeoutError{}
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed }
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct {
@@ -109,25 +111,18 @@ func (e *VersionNegotiationError) Error() string {
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
}
func (e *VersionNegotiationError) Is(target error) bool {
return target == net.ErrClosed
}
func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed }
// A StatelessResetError occurs when we receive a stateless reset.
type StatelessResetError struct {
Token protocol.StatelessResetToken
}
type StatelessResetError struct{}
var _ net.Error = &StatelessResetError{}
func (e *StatelessResetError) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", e.Token)
}
func (e *StatelessResetError) Is(target error) bool {
return target == net.ErrClosed
return "received a stateless reset"
}
func (e *StatelessResetError) Unwrap() error { return net.ErrClosed }
func (e *StatelessResetError) Timeout() bool { return false }
func (e *StatelessResetError) Temporary() bool { return true }

View File

@@ -1,52 +0,0 @@
package qtls
import (
"crypto/tls"
"fmt"
"unsafe"
)
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View File

@@ -1,70 +0,0 @@
package qtls
import (
"crypto/tls"
"sync"
)
type clientSessionCache struct {
mx sync.Mutex
getData func(earlyData bool) []byte
setData func(data []byte, earlyData bool) (allowEarlyData bool)
wrapped tls.ClientSessionCache
}
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.mx.Lock()
defer c.mx.Unlock()
if cs == nil {
c.wrapped.Put(key, nil)
return
}
ticket, state, err := cs.ResumptionState()
if err != nil || state == nil {
c.wrapped.Put(key, cs)
return
}
state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData)))
newCS, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error. Just save the original state.
c.wrapped.Put(key, cs)
return
}
c.wrapped.Put(key, newCS)
}
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
c.mx.Lock()
defer c.mx.Unlock()
cs, ok := c.wrapped.Get(key)
if !ok || cs == nil {
return cs, ok
}
ticket, state, err := cs.ResumptionState()
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
// restore QUIC transport parameters and RTT stored in state.Extra
if extra := findExtraData(state.Extra); extra != nil {
earlyData := c.setData(extra, state.EarlyData)
if state.EarlyData {
state.EarlyData = earlyData
}
}
session, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
return session, true
}

View File

@@ -1,150 +0,0 @@
package qtls
import (
"bytes"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC session ticket
state.Extra = append(state.Extra, addExtraPrefix(getData()))
if origWrapSession != nil {
return origWrapSession(cs, state)
}
b, err := conf.EncryptTicket(cs, state)
return b, err
}
origUnwrapSession := conf.UnwrapSession
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
// However, using 0-RTT is only possible with the first session ticket.
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
var unwrapCount int
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
unwrapCount++
var state *tls.SessionState
var err error
if origUnwrapSession != nil {
state, err = origUnwrapSession(identity, connState)
} else {
state, err = conf.DecryptTicket(identity, connState)
}
if err != nil || state == nil {
return nil, err
}
extra := findExtraData(state.Extra)
if extra != nil {
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
} else {
state.EarlyData = false
}
return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
func SetupConfigForClient(
qconf *tls.QUICConfig,
getData func(earlyData bool) []byte,
setData func(data []byte, earlyData bool) (allowEarlyData bool),
) {
conf := qconf.TLSConfig
if conf.ClientSessionCache != nil {
origCache := conf.ClientSessionCache
conf.ClientSessionCache = &clientSessionCache{
wrapped: origCache,
getData: getData,
setData: setData,
}
}
}
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}
const extraPrefix = "quic-go1"
func addExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View File

@@ -19,7 +19,7 @@ func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteClos
}
func (h bufferedWriteCloser) Close() error {
if err := h.Writer.Flush(); err != nil {
if err := h.Flush(); err != nil {
return err
}
return h.Closer.Close()

View File

@@ -1,21 +0,0 @@
package utils
import (
"bytes"
"io"
)
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
Uint32([]byte) uint32
Uint24([]byte) uint32
Uint16([]byte) uint16
ReadUint32(io.ByteReader) (uint32, error)
ReadUint24(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
}

View File

@@ -1,103 +0,0 @@
package utils
import (
"bytes"
"encoding/binary"
"io"
)
// BigEndian is the big-endian implementation of ByteOrder.
var BigEndian ByteOrder = bigEndian{}
type bigEndian struct{}
var _ ByteOrder = &bigEndian{}
// ReadUintN reads N bytes
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << ((length - 1 - i) * 8)
}
return res, nil
}
// ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint24 reads a uint24
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
var b1, b2, b3 uint8
var err error
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil
}
// ReadUint16 reads a uint16
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
func (bigEndian) Uint32(b []byte) uint32 {
return binary.BigEndian.Uint32(b)
}
func (bigEndian) Uint24(b []byte) uint32 {
_ = b[2] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16
}
func (bigEndian) Uint16(b []byte) uint16 {
return binary.BigEndian.Uint16(b)
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes a uint24
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})
}

View File

@@ -1,10 +0,0 @@
package utils
import "net"
func IsIPv4(ip net.IP) bool {
// If ip is not an IPv4 address, To4 returns nil.
// Note that there might be some corner cases, where this is not correct.
// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
return ip.To4() != nil
}

View File

@@ -1,36 +0,0 @@
package utils
import (
"math"
"time"
)
// InfDuration is a duration of infinite length
const InfDuration = time.Duration(math.MaxInt64)
// MinNonZeroDuration return the minimum duration that's not zero.
func MinNonZeroDuration(a, b time.Duration) time.Duration {
if a == 0 {
return b
}
if b == 0 {
return a
}
return min(a, b)
}
// MinTime returns the earlier time
func MinTime(a, b time.Time) time.Time {
if a.After(b) {
return b
}
return a
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}

View File

@@ -27,11 +27,6 @@ type RTTStats struct {
maxAckDelay time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
@@ -63,8 +58,8 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == InfDuration || sendDelta <= 0 {
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration) {
if sendDelta <= 0 {
return
}
@@ -114,18 +109,11 @@ func (r *RTTStats) SetInitialRTT(t time.Duration) {
r.latestRTT = t
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
func (r *RTTStats) ResetForPathMigration() {
r.hasMeasurement = false
r.minRTT = 0
r.latestRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs())
r.smoothedRTT = max(r.smoothedRTT, r.latestRTT)
// max_ack_delay remains valid
}

View File

@@ -2,11 +2,11 @@ package wire
import (
"errors"
"math"
"sort"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -40,7 +40,7 @@ func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
// If the delay time overflows, set it to the maximum encode-able value.
delayTime = utils.InfDuration
delayTime = time.Duration(math.MaxInt64)
}
frame.DelayTime = delayTime

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@@ -32,66 +31,23 @@ type ExtendedHeader struct {
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
startLen := b.Len()
func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return false, err
}
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
reservedBitsValid, err := h.parseLongHeader(b, v)
if err != nil {
return false, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return reservedBitsValid, err
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0xc != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.typeByte = data[0]
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
n, err := b.ReadByte()
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen2:
n, err := utils.BigEndian.ReadUint16(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen3:
n, err := utils.BigEndian.ReadUint24(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen4:
n, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
if protocol.ByteCount(len(data)) < h.Header.ParsedLen()+protocol.ByteCount(h.PacketNumberLen) {
return false, io.EOF
}
return nil
pn, err := readPacketNumber(data[h.Header.ParsedLen():], h.PacketNumberLen)
if err != nil {
return true, nil
}
h.PacketNumber = pn
reservedBitsValid := h.typeByte&0xc == 0
h.parsedLen = h.Header.ParsedLen() + protocol.ByteCount(h.PacketNumberLen)
return reservedBitsValid, err
}
// Append appends the Header.

View File

@@ -0,0 +1,21 @@
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// A Frame in QUIC
type Frame interface {
Append(b []byte, version protocol.Version) ([]byte, error)
Length(version protocol.Version) protocol.ByteCount
}
// IsProbingFrame returns true if the frame is a probing frame.
// See section 9.1 of RFC 9000.
func IsProbingFrame(f Frame) bool {
switch f.(type) {
case *PathChallengeFrame, *PathResponseFrame, *NewConnectionIDFrame:
return true
}
return false
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@@ -40,37 +39,27 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
// https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
// This function should only be called on Long Header packets for which we don't support the version.
func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
r := bytes.NewReader(data)
remaining := r.Len()
src, dest, err := parseArbitraryLenConnectionIDs(r)
return remaining - r.Len(), src, dest, err
}
func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
r.Seek(5, io.SeekStart) // skip first byte and version field
destConnIDLen, err := r.ReadByte()
if err != nil {
return nil, nil, err
startLen := len(data)
if len(data) < 6 {
return 0, nil, nil, io.EOF
}
data = data[5:] // skip first byte and version field
destConnIDLen := data[0]
data = data[1:]
destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
if _, err := io.ReadFull(r, destConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, nil, err
if len(data) < int(destConnIDLen)+1 {
return 0, nil, nil, io.EOF
}
srcConnIDLen, err := r.ReadByte()
if err != nil {
return nil, nil, err
copy(destConnID, data)
data = data[destConnIDLen:]
srcConnIDLen := data[0]
data = data[1:]
if len(data) < int(srcConnIDLen) {
return 0, nil, nil, io.EOF
}
srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
if _, err := io.ReadFull(r, srcConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, nil, err
}
return destConnID, srcConnID, nil
copy(srcConnID, data)
return startLen - len(data) + int(srcConnIDLen), destConnID, srcConnID, nil
}
func IsPotentialQUICPacket(firstByte byte) bool {
@@ -274,9 +263,9 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
func (h *Header) ParseExtended(data []byte) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
reservedBitsValid, err := extHdr.parse(data)
if err != nil {
return nil, err
}
@@ -294,3 +283,20 @@ func (h *Header) toExtendedHeader() *ExtendedHeader {
func (h *Header) PacketType() string {
return h.Type.String()
}
func readPacketNumber(data []byte, pnLen protocol.PacketNumberLen) (protocol.PacketNumber, error) {
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[0])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(binary.BigEndian.Uint16(data[:2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(uint32(data[2]) + uint32(data[1])<<8 + uint32(data[0])<<16)
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(binary.BigEndian.Uint32(data[:4]))
default:
return 0, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return pn, nil
}

View File

@@ -1,11 +0,0 @@
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// A Frame in QUIC
type Frame interface {
Append(b []byte, version protocol.Version) ([]byte, error)
Length(version protocol.Version) protocol.ByteCount
}

View File

@@ -30,7 +30,7 @@ func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFr
}
b = b[l:]
if ret > seq {
//nolint:stylecheck
//nolint:staticcheck // SA1021: Retire Prior To is the name of the field
return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
}
if len(b) == 0 {

View File

@@ -2,7 +2,6 @@ package wire
import (
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -28,25 +27,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
}
pos := 1 + connIDLen
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[pos])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(utils.BigEndian.Uint16(data[pos : pos+2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(utils.BigEndian.Uint24(data[pos : pos+3]))
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
default:
return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
pn, err := readPacketNumber(data[pos:], pnLen)
if err != nil {
return 0, 0, 0, 0, err
}
kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 {
kp = protocol.KeyPhaseOne
}
var err error
if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits
}

View File

@@ -58,7 +58,10 @@ func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, i
var frame *StreamFrame
if dataLen < protocol.MinStreamFrameBufferSize {
frame = &StreamFrame{Data: make([]byte, dataLen)}
frame = &StreamFrame{}
if dataLen > 0 {
frame.Data = make([]byte, dataLen)
}
} else {
frame = GetStreamFrame()
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
@@ -74,7 +77,7 @@ func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, i
frame.Fin = fin
frame.DataLenPresent = hasDataLen
if dataLen != 0 {
if dataLen > 0 {
copy(frame.Data, b)
}
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {

View File

@@ -245,11 +245,15 @@ func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) er
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
if port4 != 0 && ipv4 != [4]byte{} {
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
}
var ipv6 [16]byte
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
if port6 != 0 && ipv6 != [16]byte{} {
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
}
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
@@ -391,12 +395,20 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
if p.PreferredAddress != nil {
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ip4 := p.PreferredAddress.IPv4.Addr().As4()
b = append(b, ip4[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
ip6 := p.PreferredAddress.IPv6.Addr().As16()
b = append(b, ip6[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
if p.PreferredAddress.IPv4.IsValid() {
ipv4 := p.PreferredAddress.IPv4.Addr().As4()
b = append(b, ipv4[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
} else {
b = append(b, make([]byte, 6)...)
}
if p.PreferredAddress.IPv6.IsValid() {
ipv6 := p.PreferredAddress.IPv6.Addr().As16()
b = append(b, ipv6[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
} else {
b = append(b, make([]byte, 18)...)
}
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)

View File

@@ -16,11 +16,11 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon
}
b = b[n:]
if len(b) == 0 {
//nolint:stylecheck
//nolint:staticcheck // SA1021: the packet is called Version Negotiation packet
return nil, nil, nil, errors.New("Version Negotiation packet has empty version list")
}
if len(b)%4 != 0 {
//nolint:stylecheck
//nolint:staticcheck // SA1021: the packet is called Version Negotiation packet
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
}
versions := make([]protocol.Version, len(b)/4)

View File

@@ -5,34 +5,36 @@ import (
"time"
)
//go:generate go run generate_multiplexer.go ConnectionTracer connection_tracer.go multiplexer.tmpl connection_tracer_multiplexer.go
// A ConnectionTracer records events.
type ConnectionTracer struct {
StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID)
NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber)
ClosedConnection func(error)
SentTransportParameters func(*TransportParameters)
ReceivedTransportParameters func(*TransportParameters)
NegotiatedVersion func(chosen Version, clientVersions, serverVersions []Version)
ClosedConnection func(err error)
SentTransportParameters func(parameters *TransportParameters)
ReceivedTransportParameters func(parameters *TransportParameters)
RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT
SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame)
SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame)
ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedRetry func(*Header)
ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame)
ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame)
BufferedPacket func(PacketType, ByteCount)
DroppedPacket func(PacketType, PacketNumber, ByteCount, PacketDropReason)
SentLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame)
SentShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame)
ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, versions []Version)
ReceivedRetry func(hdr *Header)
ReceivedLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame)
ReceivedShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame)
BufferedPacket func(packetType PacketType, size ByteCount)
DroppedPacket func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason)
UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
AcknowledgedPacket func(EncryptionLevel, PacketNumber)
LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason)
AcknowledgedPacket func(encLevel EncryptionLevel, pn PacketNumber)
LostPacket func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason)
UpdatedMTU func(mtu ByteCount, done bool)
UpdatedCongestionState func(CongestionState)
UpdatedCongestionState func(state CongestionState)
UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
UpdatedKeyFromTLS func(encLevel EncryptionLevel, p Perspective)
UpdatedKey func(keyPhase KeyPhase, remote bool)
DroppedEncryptionLevel func(EncryptionLevel)
DroppedEncryptionLevel func(encLevel EncryptionLevel)
DroppedKey func(keyPhase KeyPhase)
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
LossTimerExpired func(TimerType, EncryptionLevel)
SetLossTimer func(timerType TimerType, encLevel EncryptionLevel, time time.Time)
LossTimerExpired func(timerType TimerType, encLevel EncryptionLevel)
LossTimerCanceled func()
ECNStateUpdated func(state ECNState, trigger ECNStateTrigger)
ChoseALPN func(protocol string)
@@ -40,232 +42,3 @@ type ConnectionTracer struct {
Close func()
Debug func(name, msg string)
}
// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers.
func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &ConnectionTracer{
StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) {
for _, t := range tracers {
if t.StartedConnection != nil {
t.StartedConnection(local, remote, srcConnID, destConnID)
}
}
},
NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
for _, t := range tracers {
if t.NegotiatedVersion != nil {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
}
}
},
ClosedConnection: func(e error) {
for _, t := range tracers {
if t.ClosedConnection != nil {
t.ClosedConnection(e)
}
}
},
SentTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.SentTransportParameters != nil {
t.SentTransportParameters(tp)
}
}
},
ReceivedTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.ReceivedTransportParameters != nil {
t.ReceivedTransportParameters(tp)
}
}
},
RestoredTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.RestoredTransportParameters != nil {
t.RestoredTransportParameters(tp)
}
}
},
SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentLongHeaderPacket != nil {
t.SentLongHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentShortHeaderPacket != nil {
t.SentShortHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range tracers {
if t.ReceivedVersionNegotiationPacket != nil {
t.ReceivedVersionNegotiationPacket(dest, src, versions)
}
}
},
ReceivedRetry: func(hdr *Header) {
for _, t := range tracers {
if t.ReceivedRetry != nil {
t.ReceivedRetry(hdr)
}
}
},
ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedLongHeaderPacket != nil {
t.ReceivedLongHeaderPacket(hdr, size, ecn, frames)
}
}
},
ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedShortHeaderPacket != nil {
t.ReceivedShortHeaderPacket(hdr, size, ecn, frames)
}
}
},
BufferedPacket: func(typ PacketType, size ByteCount) {
for _, t := range tracers {
if t.BufferedPacket != nil {
t.BufferedPacket(typ, size)
}
}
},
DroppedPacket: func(typ PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(typ, pn, size, reason)
}
}
},
UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) {
for _, t := range tracers {
if t.UpdatedMetrics != nil {
t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight)
}
}
},
AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) {
for _, t := range tracers {
if t.AcknowledgedPacket != nil {
t.AcknowledgedPacket(encLevel, pn)
}
}
},
LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) {
for _, t := range tracers {
if t.LostPacket != nil {
t.LostPacket(encLevel, pn, reason)
}
}
},
UpdatedMTU: func(mtu ByteCount, done bool) {
for _, t := range tracers {
if t.UpdatedMTU != nil {
t.UpdatedMTU(mtu, done)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {
t.UpdatedCongestionState(state)
}
}
},
UpdatedPTOCount: func(value uint32) {
for _, t := range tracers {
if t.UpdatedPTOCount != nil {
t.UpdatedPTOCount(value)
}
}
},
UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) {
for _, t := range tracers {
if t.UpdatedKeyFromTLS != nil {
t.UpdatedKeyFromTLS(encLevel, perspective)
}
}
},
UpdatedKey: func(generation KeyPhase, remote bool) {
for _, t := range tracers {
if t.UpdatedKey != nil {
t.UpdatedKey(generation, remote)
}
}
},
DroppedEncryptionLevel: func(encLevel EncryptionLevel) {
for _, t := range tracers {
if t.DroppedEncryptionLevel != nil {
t.DroppedEncryptionLevel(encLevel)
}
}
},
DroppedKey: func(generation KeyPhase) {
for _, t := range tracers {
if t.DroppedKey != nil {
t.DroppedKey(generation)
}
}
},
SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) {
for _, t := range tracers {
if t.SetLossTimer != nil {
t.SetLossTimer(typ, encLevel, exp)
}
}
},
LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) {
for _, t := range tracers {
if t.LossTimerExpired != nil {
t.LossTimerExpired(typ, encLevel)
}
}
},
LossTimerCanceled: func() {
for _, t := range tracers {
if t.LossTimerCanceled != nil {
t.LossTimerCanceled()
}
}
},
ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) {
for _, t := range tracers {
if t.ECNStateUpdated != nil {
t.ECNStateUpdated(state, trigger)
}
}
},
ChoseALPN: func(protocol string) {
for _, t := range tracers {
if t.ChoseALPN != nil {
t.ChoseALPN(protocol)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
Debug: func(name, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
}
}

View File

@@ -0,0 +1,236 @@
// Code generated by generate_multiplexer.go; DO NOT EDIT.
package logging
import (
"net"
"time"
)
func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &ConnectionTracer{
StartedConnection: func(local net.Addr, remote net.Addr, srcConnID ConnectionID, destConnID ConnectionID) {
for _, t := range tracers {
if t.StartedConnection != nil {
t.StartedConnection(local, remote, srcConnID, destConnID)
}
}
},
NegotiatedVersion: func(chosen Version, clientVersions []Version, serverVersions []Version) {
for _, t := range tracers {
if t.NegotiatedVersion != nil {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
}
}
},
ClosedConnection: func(err error) {
for _, t := range tracers {
if t.ClosedConnection != nil {
t.ClosedConnection(err)
}
}
},
SentTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.SentTransportParameters != nil {
t.SentTransportParameters(parameters)
}
}
},
ReceivedTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.ReceivedTransportParameters != nil {
t.ReceivedTransportParameters(parameters)
}
}
},
RestoredTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.RestoredTransportParameters != nil {
t.RestoredTransportParameters(parameters)
}
}
},
SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentLongHeaderPacket != nil {
t.SentLongHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentShortHeaderPacket != nil {
t.SentShortHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
ReceivedVersionNegotiationPacket: func(dest ArbitraryLenConnectionID, src ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.ReceivedVersionNegotiationPacket != nil {
t.ReceivedVersionNegotiationPacket(dest, src, versions)
}
}
},
ReceivedRetry: func(hdr *Header) {
for _, t := range tracers {
if t.ReceivedRetry != nil {
t.ReceivedRetry(hdr)
}
}
},
ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedLongHeaderPacket != nil {
t.ReceivedLongHeaderPacket(hdr, size, ecn, frames)
}
}
},
ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedShortHeaderPacket != nil {
t.ReceivedShortHeaderPacket(hdr, size, ecn, frames)
}
}
},
BufferedPacket: func(packetType PacketType, size ByteCount) {
for _, t := range tracers {
if t.BufferedPacket != nil {
t.BufferedPacket(packetType, size)
}
}
},
DroppedPacket: func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(packetType, pn, size, reason)
}
}
},
UpdatedMetrics: func(rttStats *RTTStats, cwnd ByteCount, bytesInFlight ByteCount, packetsInFlight int) {
for _, t := range tracers {
if t.UpdatedMetrics != nil {
t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight)
}
}
},
AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) {
for _, t := range tracers {
if t.AcknowledgedPacket != nil {
t.AcknowledgedPacket(encLevel, pn)
}
}
},
LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) {
for _, t := range tracers {
if t.LostPacket != nil {
t.LostPacket(encLevel, pn, reason)
}
}
},
UpdatedMTU: func(mtu ByteCount, done bool) {
for _, t := range tracers {
if t.UpdatedMTU != nil {
t.UpdatedMTU(mtu, done)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {
t.UpdatedCongestionState(state)
}
}
},
UpdatedPTOCount: func(value uint32) {
for _, t := range tracers {
if t.UpdatedPTOCount != nil {
t.UpdatedPTOCount(value)
}
}
},
UpdatedKeyFromTLS: func(encLevel EncryptionLevel, p Perspective) {
for _, t := range tracers {
if t.UpdatedKeyFromTLS != nil {
t.UpdatedKeyFromTLS(encLevel, p)
}
}
},
UpdatedKey: func(keyPhase KeyPhase, remote bool) {
for _, t := range tracers {
if t.UpdatedKey != nil {
t.UpdatedKey(keyPhase, remote)
}
}
},
DroppedEncryptionLevel: func(encLevel EncryptionLevel) {
for _, t := range tracers {
if t.DroppedEncryptionLevel != nil {
t.DroppedEncryptionLevel(encLevel)
}
}
},
DroppedKey: func(keyPhase KeyPhase) {
for _, t := range tracers {
if t.DroppedKey != nil {
t.DroppedKey(keyPhase)
}
}
},
SetLossTimer: func(timerType TimerType, encLevel EncryptionLevel, time time.Time) {
for _, t := range tracers {
if t.SetLossTimer != nil {
t.SetLossTimer(timerType, encLevel, time)
}
}
},
LossTimerExpired: func(timerType TimerType, encLevel EncryptionLevel) {
for _, t := range tracers {
if t.LossTimerExpired != nil {
t.LossTimerExpired(timerType, encLevel)
}
}
},
LossTimerCanceled: func() {
for _, t := range tracers {
if t.LossTimerCanceled != nil {
t.LossTimerCanceled()
}
}
},
ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) {
for _, t := range tracers {
if t.ECNStateUpdated != nil {
t.ECNStateUpdated(state, trigger)
}
}
},
ChoseALPN: func(protocol string) {
for _, t := range tracers {
if t.ChoseALPN != nil {
t.ChoseALPN(protocol)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
Debug: func(name string, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
}
}

View File

@@ -0,0 +1,161 @@
//go:build generate
package main
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
"strings"
"text/template"
"golang.org/x/tools/imports"
)
func main() {
if len(os.Args) != 5 {
log.Fatalf("Usage: %s <struct_name> <input_file> <template_file> <output_file>", os.Args[0])
}
structName := os.Args[1]
inputFile := os.Args[2]
templateFile := os.Args[3]
outputFile := os.Args[4]
fset := token.NewFileSet()
// Parse the input file containing the struct type
file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors)
if err != nil {
log.Fatalf("Failed to parse file: %v", err)
}
var fields []*ast.Field
// Find the specified struct type in the AST
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok || typeSpec.Name.Name != structName {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
log.Fatalf("%s is not a struct", structName)
}
fields = structType.Fields.List
break
}
}
if fields == nil {
log.Fatalf("Could not find %s type", structName)
}
// Prepare data for the template
type FieldData struct {
Name string
Params string
Args string
HasParams bool
ReturnTypes string
HasReturn bool
}
var fieldDataList []FieldData
for _, field := range fields {
funcType, ok := field.Type.(*ast.FuncType)
if !ok {
continue
}
for _, name := range field.Names {
fieldData := FieldData{Name: name.Name}
// extract parameters
var params []string
var args []string
if funcType.Params != nil {
for i, param := range funcType.Params.List {
// We intentionally reject unnamed (and, further down, "_") function parameters.
// We could auto-generate parameter names,
// but having meaningful variable names will be more helpful for the user.
if len(param.Names) == 0 {
log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name)
}
var buf bytes.Buffer
printer.Fprint(&buf, fset, param.Type)
paramType := buf.String()
for _, paramName := range param.Names {
if paramName.Name == "_" {
log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name)
}
params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType))
args = append(args, paramName.Name)
}
}
}
fieldData.Params = strings.Join(params, ", ")
fieldData.Args = strings.Join(args, ", ")
fieldData.HasParams = len(params) > 0
// extract return types
if funcType.Results != nil && len(funcType.Results.List) > 0 {
fieldData.HasReturn = true
var returns []string
for _, result := range funcType.Results.List {
var buf bytes.Buffer
printer.Fprint(&buf, fset, result.Type)
returns = append(returns, buf.String())
}
if len(returns) == 1 {
fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0])
} else {
fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
}
}
fieldDataList = append(fieldDataList, fieldData)
}
}
// Read the template from file
templateContent, err := os.ReadFile(templateFile)
if err != nil {
log.Fatalf("Failed to read template file: %v", err)
}
// Generate the code using the template
tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent))
if err != nil {
log.Fatalf("Failed to parse template: %v", err)
}
var generatedCode bytes.Buffer
generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n")
if err = tmpl.Execute(&generatedCode, map[string]interface{}{
"Fields": fieldDataList,
"StructName": structName,
}); err != nil {
log.Fatalf("Failed to execute template: %v", err)
}
// Format the generated code and add imports
formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil)
if err != nil {
log.Fatalf("Failed to process imports: %v", err)
}
if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil {
log.Fatalf("Failed to write output file: %v", err)
}
}

View File

@@ -36,8 +36,8 @@ type (
StreamNum = protocol.StreamNum
// The StreamType is the type of the stream (unidirectional or bidirectional).
StreamType = protocol.StreamType
// The VersionNumber is the QUIC version.
VersionNumber = protocol.Version
// The Version is the QUIC version.
Version = protocol.Version
// The Header is the QUIC packet header, before removing header protection.
Header = wire.Header
@@ -72,27 +72,27 @@ const (
const (
// KeyPhaseZero is key phase bit 0
KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero
KeyPhaseZero = protocol.KeyPhaseZero
// KeyPhaseOne is key phase bit 1
KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne
KeyPhaseOne = protocol.KeyPhaseOne
)
const (
// PerspectiveServer is used for a QUIC server
PerspectiveServer Perspective = protocol.PerspectiveServer
PerspectiveServer = protocol.PerspectiveServer
// PerspectiveClient is used for a QUIC client
PerspectiveClient Perspective = protocol.PerspectiveClient
PerspectiveClient = protocol.PerspectiveClient
)
const (
// EncryptionInitial is the Initial encryption level
EncryptionInitial EncryptionLevel = protocol.EncryptionInitial
EncryptionInitial = protocol.EncryptionInitial
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake
EncryptionHandshake = protocol.EncryptionHandshake
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT EncryptionLevel = protocol.Encryption1RTT
Encryption1RTT = protocol.Encryption1RTT
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT EncryptionLevel = protocol.Encryption0RTT
Encryption0RTT = protocol.Encryption0RTT
)
const (

View File

@@ -0,0 +1,21 @@
package logging
func NewMultiplexed{{ .StructName }} (tracers ...*{{ .StructName }}) *{{ .StructName }} {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &{{ .StructName }}{
{{- range .Fields }}
{{ .Name }}: func({{ .Params }}){{ .ReturnTypes }} {
for _, t := range tracers {
if t.{{ .Name }} != nil {
t.{{ .Name }}({{ .Args }})
}
}
},
{{- end }}
}
}

View File

@@ -2,58 +2,13 @@ package logging
import "net"
//go:generate go run generate_multiplexer.go Tracer tracer.go multiplexer.tmpl tracer_multiplexer.go
// A Tracer traces events.
type Tracer struct {
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
SentPacket func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame)
SentVersionNegotiationPacket func(dest net.Addr, destConnID, srcConnID ArbitraryLenConnectionID, versions []Version)
DroppedPacket func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason)
Debug func(name, msg string)
Close func()
}
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &Tracer{
SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range tracers {
if t.SentPacket != nil {
t.SentPacket(remote, hdr, size, frames)
}
}
},
SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range tracers {
if t.SentVersionNegotiationPacket != nil {
t.SentVersionNegotiationPacket(remote, dest, src, versions)
}
}
},
DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(remote, typ, size, reason)
}
}
},
Debug: func(name, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
}
}

View File

@@ -0,0 +1,51 @@
// Code generated by generate_multiplexer.go; DO NOT EDIT.
package logging
import "net"
func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &Tracer{
SentPacket: func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range tracers {
if t.SentPacket != nil {
t.SentPacket(dest, hdr, size, frames)
}
}
},
SentVersionNegotiationPacket: func(dest net.Addr, destConnID ArbitraryLenConnectionID, srcConnID ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.SentVersionNegotiationPacket != nil {
t.SentVersionNegotiationPacket(dest, destConnID, srcConnID, versions)
}
}
},
DroppedPacket: func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(addr, packetType, size, reason)
}
}
},
Debug: func(name string, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
}
}

View File

@@ -63,9 +63,11 @@ type TimerType uint8
const (
// TimerTypeACK is the timer type for the early retransmit timer
TimerTypeACK TimerType = iota
TimerTypeACK TimerType = iota + 1
// TimerTypePTO is the timer type for the PTO retransmit timer
TimerTypePTO
// TimerTypePathProbe is the timer type for the path probe retransmit timer
TimerTypePathProbe
)
// TimeoutReason is the reason why a connection is closed

View File

@@ -14,23 +14,17 @@ type Sender = sender
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI"
type StreamI = streamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream"
type CryptoStream = cryptoStream
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI"
type ReceiveStreamI = receiveStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
type SendStreamI = sendStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter"
type StreamGetter = streamGetter
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
type StreamSender = streamSender
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler"
type CryptoDataHandler = cryptoDataHandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter"
type StreamControlFrameGetter = streamControlFrameGetter
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource"
type FrameSource = frameSource
@@ -67,10 +61,4 @@ type PacketHandler = packetHandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
// See https://github.com/golang/mock/issues/244 for details.
//
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn"

View File

@@ -13,16 +13,17 @@ import (
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start()
Start(now time.Time)
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount)
Reset(now time.Time, start, max protocol.ByteCount)
}
const (
// At some point, we have to stop searching for a higher MTU.
// We're happy to send a packet that's 10 bytes smaller than the actual MTU.
maxMTUDiff = 20
maxMTUDiff protocol.ByteCount = 20
// send a probe packet every mtuProbeDelay RTTs
mtuProbeDelay = 5
// Once maxLostMTUProbes MTU probe packets larger than a certain size are lost,
@@ -88,18 +89,21 @@ const (
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
min protocol.ByteCount
limit protocol.ByteCount
// on initialization, we treat the maximum size as the first "lost" packet
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
// The generation is used to ignore ACKs / losses for probe packets sent before a reset.
// Resets happen when the connection is migrated to a new path.
// We're therefore not concerned about overflows of this counter.
generation uint8
tracer *logging.ConnectionTracer
}
@@ -108,17 +112,19 @@ var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
min: start,
limit: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
inFlight: protocol.InvalidByteCount,
rttStats: rttStats,
tracer: tracer,
}
f.init(start, max)
return f
}
func (f *mtuFinder) init(start, max protocol.ByteCount) {
f.min = start
for i := range f.lost {
if i == 0 {
f.lost[i] = max
@@ -126,7 +132,6 @@ func newMTUDiscoverer(
}
f.lost[i] = protocol.InvalidByteCount
}
return f
}
func (f *mtuFinder) done() bool {
@@ -142,8 +147,8 @@ func (f *mtuFinder) max() protocol.ByteCount {
return f.lost[len(f.lost)-1]
}
func (f *mtuFinder) Start() {
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
func (f *mtuFinder) Start(now time.Time) {
f.lastProbeTime = now // makes sure the first probe packet is not sent immediately
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
@@ -156,18 +161,18 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT()))
}
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount) {
var size protocol.ByteCount
if f.lastProbeWasLost {
size = (f.min + f.lost[0]) / 2
} else {
size = (f.min + f.max()) / 2
}
f.lastProbeTime = time.Now()
f.lastProbeTime = now
f.inFlight = size
return ackhandler.Frame{
Frame: &wire.PingFrame{},
Handler: &mtuFinderAckHandler{f},
Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation},
}, size
}
@@ -175,13 +180,26 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.min
}
func (f *mtuFinder) Reset(now time.Time, start, max protocol.ByteCount) {
f.generation++
f.lastProbeTime = now
f.lastProbeWasLost = false
f.inFlight = protocol.InvalidByteCount
f.init(start, max)
}
type mtuFinderAckHandler struct {
*mtuFinder
generation uint8
}
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// ACK for probe sent before reset
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnAcked callback called although there's no MTU probe packet in flight")
@@ -209,10 +227,13 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
h.mtuIncreased(size)
}
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// probe sent before reset received
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")

View File

@@ -1,75 +0,0 @@
package quic
import (
"fmt"
"net"
"sync"
"github.com/quic-go/quic-go/internal/utils"
)
var (
connMuxerOnce sync.Once
connMuxer multiplexer
)
type indexableConn interface{ LocalAddr() net.Addr }
type multiplexer interface {
AddConn(conn indexableConn)
RemoveConn(indexableConn) error
}
// The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the connection handler.
type connMultiplexer struct {
mutex sync.Mutex
conns map[string] /* LocalAddr().String() */ indexableConn
logger utils.Logger
}
var _ multiplexer = &connMultiplexer{}
func getMultiplexer() multiplexer {
connMuxerOnce.Do(func() {
connMuxer = &connMultiplexer{
conns: make(map[string]indexableConn),
logger: utils.DefaultLogger.WithPrefix("muxer"),
}
})
return connMuxer
}
func (m *connMultiplexer) index(addr net.Addr) string {
return addr.Network() + " " + addr.String()
}
func (m *connMultiplexer) AddConn(c indexableConn) {
m.mutex.Lock()
defer m.mutex.Unlock()
connIndex := m.index(c.LocalAddr())
p, ok := m.conns[connIndex]
if ok {
// Panics if we're already listening on this connection.
// This is a safeguard because we're introducing a breaking API change, see
// https://github.com/quic-go/quic-go/issues/3727 for details.
// We'll remove this at a later time, when most users of the library have made the switch.
panic("connection already exists") // TODO: write a nice message
}
m.conns[connIndex] = p
}
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
m.mutex.Lock()
defer m.mutex.Unlock()
connIndex := m.index(c.LocalAddr())
if _, ok := m.conns[connIndex]; !ok {
return fmt.Errorf("cannote remove connection, connection is unknown")
}
delete(m.conns, connIndex)
return nil
}

View File

@@ -3,12 +3,12 @@
# Install Go manually, since oss-fuzz ships with an outdated Go version.
# See https://github.com/google/oss-fuzz/pull/10643.
export CXX="${CXX} -lresolv" # required by Go 1.20
wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \
wget https://go.dev/dl/go1.23.0.linux-amd64.tar.gz \
&& mkdir temp-go \
&& rm -rf /root/.go/* \
&& tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \
&& tar -C temp-go/ -xzf go1.23.0.linux-amd64.tar.gz \
&& mv temp-go/go/* /root/.go/ \
&& rm -rf temp-go go1.22.0.linux-amd64.tar.gz
&& rm -rf temp-go go1.23.0.linux-amd64.tar.gz
(
# fuzz qpack

View File

@@ -1,10 +1,6 @@
package quic
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash"
"io"
"net"
"sync"
@@ -56,15 +52,12 @@ type packetHandlerMap struct {
deleteRetiredConnsAfter time.Duration
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
func newPacketHandlerMap(enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
h := &packetHandlerMap{
closeChan: make(chan struct{}),
handlers: make(map[protocol.ConnectionID]packetHandler),
@@ -73,9 +66,6 @@ func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePa
enqueueClosePacket: enqueueClosePacket,
logger: logger,
}
if key != nil {
h.statelessResetHasher = hmac.New(sha256.New, key[:])
}
if h.logger.Debug() {
go h.logUsage()
}
@@ -236,20 +226,3 @@ func (h *packetHandlerMap) Close(e error) {
h.mutex.Unlock()
wg.Wait()
}
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken
if h.statelessResetHasher == nil {
// Return a random stateless reset token.
// This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection.
rand.Read(token[:])
return token
}
h.statelessResetMutex.Lock()
h.statelessResetHasher.Write(connID.Bytes())
copy(token[:], h.statelessResetHasher.Sum(nil))
h.statelessResetHasher.Reset()
h.statelessResetMutex.Unlock()
return token
}

View File

@@ -5,8 +5,8 @@ import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/exp/rand"
"math/rand/v2"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake"
@@ -18,12 +18,13 @@ import (
var errNothingToPack = errors.New("nothing to pack")
type packer interface {
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(_ *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error)
PackPTOProbePacket(_ protocol.EncryptionLevel, _ protocol.ByteCount, addPingIfEmpty bool, now time.Time, v protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackPathProbePacket(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte)
@@ -56,6 +57,7 @@ type shortHeaderPacket struct {
Ack *wire.AckFrame
Length protocol.ByteCount
IsPathMTUProbePacket bool
IsPathProbePacket bool
// used for logging
DestConnID protocol.ConnectionID
@@ -106,12 +108,11 @@ type sealingManager interface {
type frameSource interface {
HasData() bool
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
Append([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, time.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount)
}
type ackFrameSource interface {
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame
}
type packetPacker struct {
@@ -121,8 +122,8 @@ type packetPacker struct {
perspective protocol.Perspective
cryptoSetup sealingManager
initialStream cryptoStream
handshakeStream cryptoStream
initialStream *cryptoStream
handshakeStream *cryptoStream
token []byte
@@ -141,7 +142,7 @@ var _ packer = &packetPacker{}
func newPacketPacker(
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream, handshakeStream cryptoStream,
initialStream, handshakeStream *cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
cryptoSetup sealingManager,
@@ -150,7 +151,7 @@ func newPacketPacker(
datagramQueue *datagramQueue,
perspective protocol.Perspective,
) *packetPacker {
var b [8]byte
var b [16]byte
_, _ = crand.Read(b[:])
return &packetPacker{
@@ -164,7 +165,7 @@ func newPacketPacker(
perspective: perspective,
framer: framer,
acks: acks,
rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))),
rand: *rand.New(rand.NewPCG(binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:]))),
pnManager: packetNumberManager,
}
}
@@ -269,17 +270,17 @@ func (p *packetPacker) packConnectionClose(
if sealers[i] == nil {
continue
}
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
}
if encLevel == protocol.Encryption1RTT {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v)
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
} else {
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
}
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
if err != nil {
return nil, err
@@ -328,7 +329,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize,
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) {
var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
@@ -342,7 +343,15 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
}
var size protocol.ByteCount
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v)
initialHdr, initialPayload = p.maybeGetCryptoPacket(
maxSize-protocol.ByteCount(initialSealer.Overhead()),
protocol.EncryptionInitial,
now,
false,
onlyAck,
true,
v,
)
if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
}
@@ -350,14 +359,22 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
// Add a Handshake packet.
var handshakeSealer sealer
if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
var err error
handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v)
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(
maxSize-size-protocol.ByteCount(handshakeSealer.Overhead()),
protocol.EncryptionHandshake,
now,
false,
onlyAck,
size == 0,
v,
)
if handshakePayload.length > 0 {
s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
@@ -370,7 +387,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
var oneRTTSealer handshake.ShortHeaderSealer
var connID protocol.ConnectionID
var kp protocol.KeyPhaseBit
if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
var err error
oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
@@ -381,7 +398,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxSize-size, onlyAck, size == 0, now, v)
if oneRTTPayload.length > 0 {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
}
@@ -392,7 +409,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
return nil, err
}
if zeroRTTSealer != nil {
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v)
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxSize-size, now, v)
if zeroRTTPayload.length > 0 {
size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead())
}
@@ -410,7 +427,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
padding := p.initialPaddingLen(initialPayload.frames, size, maxSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
@@ -431,7 +448,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
}
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload.length > 0 {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v)
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxSize, oneRTTSealer, false, v)
if err != nil {
return nil, err
}
@@ -442,19 +459,25 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
func (p *packetPacker) PackAckOnlyPacket(maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
packet, err := p.appendPacket(buf, true, maxSize, now, v)
return packet, buf, err
}
// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxPacketSize, v)
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxSize, now, v)
}
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) {
func (p *packetPacker) appendPacket(
buf *packetBuffer,
onlyAck bool,
maxPacketSize protocol.ByteCount,
now time.Time,
v protocol.Version,
) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, err
@@ -462,7 +485,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v)
pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, now, v)
if pl.length == 0 {
return shortHeaderPacket{}, errNothingToPack
}
@@ -471,9 +494,16 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
}
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) {
func (p *packetPacker) maybeGetCryptoPacket(
maxPacketSize protocol.ByteCount,
encLevel protocol.EncryptionLevel,
now time.Time,
addPingIfEmpty bool,
onlyAck, ackAllowed bool,
v protocol.Version,
) (*wire.ExtendedHeader, payload) {
if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
if ack := p.acks.GetAckFrame(encLevel, now, true); ack != nil {
return p.getLongHeader(encLevel, v), payload{
ack: ack,
length: ack.Length(v),
@@ -482,32 +512,33 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
return nil, payload{}
}
var s cryptoStream
var handler ackhandler.FrameHandler
var hasRetransmission bool
var s *cryptoStream
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
switch encLevel {
case protocol.EncryptionInitial:
s = p.initialStream
handler = p.retransmissionQueue.InitialAckHandler()
hasRetransmission = p.retransmissionQueue.HasInitialData()
case protocol.EncryptionHandshake:
s = p.handshakeStream
handler = p.retransmissionQueue.HandshakeAckHandler()
hasRetransmission = p.retransmissionQueue.HasHandshakeData()
}
hasData := s.HasData()
handler := p.retransmissionQueue.AckHandler(encLevel)
hasRetransmission := p.retransmissionQueue.HasData(encLevel)
var ack *wire.AckFrame
if ackAllowed {
ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
ack = p.acks.GetAckFrame(encLevel, now, !hasRetransmission && !hasData)
}
var pl payload
if !hasData && !hasRetransmission && ack == nil {
// nothing to send
return nil, payload{}
if !addPingIfEmpty {
// nothing to send
return nil, payload{}
}
ping := &wire.PingFrame{}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}})
pl.length += ping.Length(v)
}
var pl payload
if ack != nil {
pl.ack = ack
pl.length = ack.Length(v)
@@ -517,49 +548,54 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
maxPacketSize -= hdr.GetLength(v)
if hasRetransmission {
for {
var f ackhandler.Frame
//nolint:exhaustive // 0-RTT packets can't contain any retransmission.s
switch encLevel {
case protocol.EncryptionInitial:
f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.InitialAckHandler()
case protocol.EncryptionHandshake:
f.Frame = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.HandshakeAckHandler()
}
if f.Frame == nil {
frame := p.retransmissionQueue.GetFrame(encLevel, maxPacketSize, v)
if frame == nil {
break
}
pl.frames = append(pl.frames, f)
frameLen := f.Frame.Length(v)
pl.frames = append(pl.frames, ackhandler.Frame{
Frame: frame,
Handler: p.retransmissionQueue.AckHandler(encLevel),
})
frameLen := frame.Length(v)
pl.length += frameLen
maxPacketSize -= frameLen
}
} else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize)
pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: cf, Handler: handler})
pl.length += cf.Length(v)
}
return hdr, pl
}
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) {
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (*wire.ExtendedHeader, payload) {
if p.perspective != protocol.PerspectiveClient {
return nil, payload{}
}
hdr := p.getLongHeader(protocol.Encryption0RTT, v)
maxPayloadSize := maxPacketSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead())
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
maxPayloadSize := maxSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead())
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, now, v)
}
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
func (p *packetPacker) maybeGetShortHeaderPacket(
sealer handshake.ShortHeaderSealer,
hdrLen, maxPacketSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, now, v)
}
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
func (p *packetPacker) maybeGetAppDataPacket(
maxPayloadSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, now, v)
// check if we have anything to send
if len(pl.frames) == 0 && len(pl.streamFrames) == 0 {
@@ -581,21 +617,26 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
return pl
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload {
func (p *packetPacker) composeNextPacket(
maxPayloadSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, true); ack != nil {
return payload{ack: ack, length: ack.Length(v)}
}
return payload{}
}
hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData()
hasRetransmission := p.retransmissionQueue.HasData(protocol.Encryption1RTT)
var hasAck bool
var pl payload
if ackAllowed {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, !hasRetransmission && !hasData); ack != nil {
pl.ack = ack
pl.length += ack.Length(v)
hasAck = true
@@ -605,7 +646,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if p.datagramQueue != nil {
if f := p.datagramQueue.Peek(); f != nil {
size := f.Length(v)
if size <= maxFrameSize-pl.length { // DATAGRAM frame fits
if size <= maxPayloadSize-pl.length { // DATAGRAM frame fits
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
pl.length += size
p.datagramQueue.Pop()
@@ -625,15 +666,15 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if hasRetransmission {
for {
remainingLen := maxFrameSize - pl.length
remainingLen := maxPayloadSize - pl.length
if remainingLen < protocol.MinStreamFrameSize {
break
}
f := p.retransmissionQueue.GetAppDataFrame(remainingLen, v)
f := p.retransmissionQueue.GetFrame(protocol.Encryption1RTT, remainingLen, v)
if f == nil {
break
}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AppDataAckHandler()})
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AckHandler(protocol.Encryption1RTT)})
pl.length += f.Length(v)
}
}
@@ -641,51 +682,37 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if hasData {
var lengthAdded protocol.ByteCount
startLen := len(pl.frames)
pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v)
pl.frames, pl.streamFrames, lengthAdded = p.framer.Append(pl.frames, pl.streamFrames, maxPayloadSize-pl.length, now, v)
pl.length += lengthAdded
// add handlers for the control frames that were added
for i := startLen; i < len(pl.frames); i++ {
if pl.frames[i].Handler != nil {
continue
}
switch pl.frames[i].Frame.(type) {
case *wire.PathChallengeFrame, *wire.PathResponseFrame:
// Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet.
// PATH_CHALLENGE and PATH_RESPONSE are never retransmitted.
default:
pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
// we might be packing a 0-RTT packet, but we need to use the 1-RTT ack handler anyway
pl.frames[i].Handler = p.retransmissionQueue.AckHandler(protocol.Encryption1RTT)
}
}
pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
pl.length += lengthAdded
}
return pl
}
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
func (p *packetPacker) PackPTOProbePacket(
encLevel protocol.EncryptionLevel,
maxPacketSize protocol.ByteCount,
addPingIfEmpty bool,
now time.Time,
v protocol.Version,
) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
kp := s.KeyPhase()
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v)
if pl.length == 0 {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
return packet, nil
return p.packPTOProbePacket1RTT(maxPacketSize, addPingIfEmpty, now, v)
}
var hdr *wire.ExtendedHeader
var pl payload
var sealer handshake.LongHeaderSealer
//nolint:exhaustive // Probe packets are never sent for 0-RTT.
switch encLevel {
@@ -695,18 +722,24 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
if err != nil {
return nil, err
}
hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v)
default:
panic("unknown encryption level")
}
hdr, pl := p.maybeGetCryptoPacket(
maxPacketSize-protocol.ByteCount(sealer.Overhead()),
encLevel,
now,
addPingIfEmpty,
false,
true,
v,
)
if pl.length == 0 {
return nil, nil
}
@@ -726,6 +759,34 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return packet, nil
}
func (p *packetPacker) packPTOProbePacket1RTT(maxPacketSize protocol.ByteCount, addPingIfEmpty bool, now time.Time, v protocol.Version) (*coalescedPacket, error) {
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
kp := s.KeyPhase()
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, now, v)
if pl.length == 0 {
if !addPingIfEmpty {
return nil, nil
}
ping := &wire.PingFrame{}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}})
pl.length += ping.Length(v)
}
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
return packet, nil
}
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{
frames: []ackhandler.Frame{ping},
@@ -744,6 +805,30 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
return packet, buffer, err
}
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, frames []ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
buf := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, nil, err
}
var l protocol.ByteCount
for _, f := range frames {
l += f.Frame.Length(v)
}
payload := payload{
frames: frames,
length: l,
}
padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v)
if err != nil {
return shortHeaderPacket{}, nil, err
}
packet.IsPathProbePacket = true
return packet, buf, err
}
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{
@@ -909,3 +994,10 @@ func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.Pack
func (p *packetPacker) SetToken(token []byte) {
p.token = token
}
type emptyHandler struct{}
var _ ackhandler.FrameHandler = emptyHandler{}
func (emptyHandler) OnAcked(wire.Frame) {}
func (emptyHandler) OnLost(wire.Frame) {}

View File

@@ -1,7 +1,6 @@
package quic
import (
"bytes"
"fmt"
"time"
@@ -53,7 +52,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
@@ -65,7 +64,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -75,7 +74,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -85,7 +84,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@@ -125,8 +124,8 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption.
@@ -187,21 +186,18 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
}
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, v)
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
return extHdr, err
}
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
hdrLen := hdr.ParsedLen()
if protocol.ByteCount(len(data)) < hdrLen+4+16 {
//nolint:stylecheck
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
}
// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
// 1. save a copy of the 4 bytes
@@ -214,7 +210,7 @@ func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v proto
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, v)
extHdr, parseErr := hdr.ParseExtended(data)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}

205
vendor/github.com/quic-go/quic-go/path_manager.go generated vendored Normal file
View File

@@ -0,0 +1,205 @@
package quic
import (
"crypto/rand"
"net"
"slices"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type pathID int64
const invalidPathID pathID = -1
// Maximum number of paths to keep track of.
// If the peer probes another path (before the pathTimeout of an existing path expires),
// this probing attempt is ignored.
const maxPaths = 3
// If no packet is received for a path for pathTimeout,
// the path can be evicted when the peer probes another path.
// This prevents an attacker from churning through paths by duplicating packets and
// sending them with spoofed source addresses.
const pathTimeout = 5 * time.Second
type path struct {
id pathID
addr net.Addr
lastPacketTime time.Time
pathChallenge [8]byte
validated bool
rcvdNonProbing bool
}
type pathManager struct {
nextPathID pathID
// ordered by lastPacketTime, with the most recently used path at the end
paths []*path
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
logger utils.Logger
}
func newPathManager(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
logger utils.Logger,
) *pathManager {
return &pathManager{
paths: make([]*path, 0, maxPaths+1),
getConnID: getConnID,
retireConnID: retireConnID,
logger: logger,
}
}
// Returns a path challenge frame if one should be sent.
// May return nil.
func (pm *pathManager) HandlePacket(
remoteAddr net.Addr,
t time.Time,
pathChallenge *wire.PathChallengeFrame, // may be nil if the packet didn't contain a PATH_CHALLENGE
isNonProbing bool,
) (_ protocol.ConnectionID, _ []ackhandler.Frame, shouldSwitch bool) {
var p *path
for i, path := range pm.paths {
if addrsEqual(path.addr, remoteAddr) {
p = path
p.lastPacketTime = t
// already sent a PATH_CHALLENGE for this path
if isNonProbing {
path.rcvdNonProbing = true
}
if pm.logger.Debug() {
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", remoteAddr, path.validated)
}
shouldSwitch = path.validated && path.rcvdNonProbing
if i != len(pm.paths)-1 {
// move the path to the end of the list
pm.paths = slices.Delete(pm.paths, i, i+1)
pm.paths = append(pm.paths, p)
}
if pathChallenge == nil {
return protocol.ConnectionID{}, nil, shouldSwitch
}
}
}
if len(pm.paths) >= maxPaths {
if pm.paths[0].lastPacketTime.Add(pathTimeout).After(t) {
if pm.logger.Debug() {
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", remoteAddr, len(pm.paths))
}
return protocol.ConnectionID{}, nil, shouldSwitch
}
// evict the oldest path, if the last packet was received more than pathTimeout ago
pm.retireConnID(pm.paths[0].id)
pm.paths = pm.paths[1:]
}
var pathID pathID
if p != nil {
pathID = p.id
} else {
pathID = pm.nextPathID
}
// previously unseen path, initiate path validation by sending a PATH_CHALLENGE
connID, ok := pm.getConnID(pathID)
if !ok {
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", remoteAddr)
return protocol.ConnectionID{}, nil, shouldSwitch
}
frames := make([]ackhandler.Frame, 0, 2)
if p == nil {
var pathChallengeData [8]byte
rand.Read(pathChallengeData[:])
p = &path{
id: pm.nextPathID,
addr: remoteAddr,
lastPacketTime: t,
rcvdNonProbing: isNonProbing,
pathChallenge: pathChallengeData,
}
pm.nextPathID++
pm.paths = append(pm.paths, p)
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
Handler: (*pathManagerAckHandler)(pm),
})
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", remoteAddr)
}
if pathChallenge != nil {
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathResponseFrame{Data: pathChallenge.Data},
Handler: (*pathManagerAckHandler)(pm),
})
}
return connID, frames, shouldSwitch
}
func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) {
for _, p := range pm.paths {
if f.Data == p.pathChallenge {
// path validated
p.validated = true
pm.logger.Debugf("path %s validated", p.addr)
break
}
}
}
// SwitchToPath is called when the connection switches to a new path
func (pm *pathManager) SwitchToPath(addr net.Addr) {
// retire all other paths
for _, path := range pm.paths {
if addrsEqual(path.addr, addr) {
pm.logger.Debugf("switching to path %d (%s)", path.id, addr)
continue
}
pm.retireConnID(path.id)
}
clear(pm.paths)
pm.paths = pm.paths[:0]
}
type pathManagerAckHandler pathManager
var _ ackhandler.FrameHandler = &pathManagerAckHandler{}
// Acknowledging the frame doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerAckHandler) OnAcked(f wire.Frame) {}
func (pm *pathManagerAckHandler) OnLost(f wire.Frame) {
pc, ok := f.(*wire.PathChallengeFrame)
if !ok {
return
}
for i, path := range pm.paths {
if path.pathChallenge == pc.Data {
pm.paths = slices.Delete(pm.paths, i, i+1)
pm.retireConnID(path.id)
break
}
}
}
func addrsEqual(addr1, addr2 net.Addr) bool {
if addr1 == nil || addr2 == nil {
return false
}
a1, ok1 := addr1.(*net.UDPAddr)
a2, ok2 := addr2.(*net.UDPAddr)
if ok1 && ok2 {
return a1.IP.Equal(a2.IP) && a1.Port == a2.Port
}
return addr1.String() == addr2.String()
}

View File

@@ -0,0 +1,311 @@
package quic
import (
"context"
"crypto/rand"
"errors"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
var (
// ErrPathClosed is returned when trying to switch to a path that has been closed.
ErrPathClosed = errors.New("path closed")
// ErrPathNotValidated is returned when trying to use a path before path probing has completed.
ErrPathNotValidated = errors.New("path not yet validated")
)
var errPathDoesNotExist = errors.New("path does not exist")
// Path is a network path.
type Path struct {
id pathID
pathManager *pathManagerOutgoing
tr *Transport
initialRTT time.Duration
enablePath func()
validated atomic.Bool
abandon chan struct{}
}
func (p *Path) Probe(ctx context.Context) error {
path := p.pathManager.addPath(p, p.enablePath)
p.pathManager.enqueueProbe(p)
nextProbeDur := p.initialRTT
var timer *time.Timer
var timerChan <-chan time.Time
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-path.Validated():
p.validated.Store(true)
return nil
case <-timerChan:
p.pathManager.enqueueProbe(p)
case <-path.ProbeSent():
case <-p.abandon:
return ErrPathClosed
}
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(nextProbeDur)
timerChan = timer.C
nextProbeDur *= 2 // exponential backoff
}
}
// Switch switches the QUIC connection to this path.
// It immediately stops sending on the old path, and sends on this new path.
func (p *Path) Switch() error {
if err := p.pathManager.switchToPath(p.id); err != nil {
switch {
case errors.Is(err, ErrPathNotValidated):
return err
case errors.Is(err, errPathDoesNotExist) && !p.validated.Load():
select {
case <-p.abandon:
return ErrPathClosed
default:
return ErrPathNotValidated
}
default:
return ErrPathClosed
}
}
return nil
}
// Close abandons a path.
// It is not possible to close the path thats currently active.
// After closing, it is not possible to probe this path again.
func (p *Path) Close() error {
select {
case <-p.abandon:
return nil
default:
}
if err := p.pathManager.removePath(p.id); err != nil {
return err
}
close(p.abandon)
return nil
}
type pathOutgoing struct {
pathChallenges [][8]byte // length is implicitly limited by exponential backoff
tr *Transport
isValidated bool
probeSent chan struct{} // receives when a PATH_CHALLENGE is sent
validated chan struct{} // closed when the path the corresponding PATH_RESPONSE is received
enablePath func()
}
func (p *pathOutgoing) ProbeSent() <-chan struct{} { return p.probeSent }
func (p *pathOutgoing) Validated() <-chan struct{} { return p.validated }
type pathManagerOutgoing struct {
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
scheduleSending func()
mx sync.Mutex
activePath pathID
pathsToProbe []pathID
paths map[pathID]*pathOutgoing
nextPathID pathID
pathToSwitchTo *pathOutgoing
}
func newPathManagerOutgoing(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
scheduleSending func(),
) *pathManagerOutgoing {
return &pathManagerOutgoing{
activePath: 0, // at initialization time, we're guaranteed to be using the handshake path
nextPathID: 1,
getConnID: getConnID,
retireConnID: retireConnID,
scheduleSending: scheduleSending,
paths: make(map[pathID]*pathOutgoing, 4),
}
}
func (pm *pathManagerOutgoing) addPath(p *Path, enablePath func()) *pathOutgoing {
pm.mx.Lock()
defer pm.mx.Unlock()
// path might already exist, and just being re-probed
if existingPath, ok := pm.paths[p.id]; ok {
existingPath.validated = make(chan struct{})
return existingPath
}
path := &pathOutgoing{
tr: p.tr,
probeSent: make(chan struct{}, 1),
validated: make(chan struct{}),
enablePath: enablePath,
}
pm.paths[p.id] = path
return path
}
func (pm *pathManagerOutgoing) enqueueProbe(p *Path) {
pm.mx.Lock()
pm.pathsToProbe = append(pm.pathsToProbe, p.id)
pm.mx.Unlock()
pm.scheduleSending()
}
func (pm *pathManagerOutgoing) removePath(id pathID) error {
if err := pm.removePathImpl(id); err != nil {
return err
}
pm.scheduleSending()
return nil
}
func (pm *pathManagerOutgoing) removePathImpl(id pathID) error {
pm.mx.Lock()
defer pm.mx.Unlock()
if id == pm.activePath {
return errors.New("cannot close active path")
}
p, ok := pm.paths[id]
if !ok {
return nil
}
if len(p.pathChallenges) > 0 {
pm.retireConnID(id)
}
delete(pm.paths, id)
return nil
}
func (pm *pathManagerOutgoing) switchToPath(id pathID) error {
pm.mx.Lock()
defer pm.mx.Unlock()
p, ok := pm.paths[id]
if !ok {
return errPathDoesNotExist
}
if !p.isValidated {
return ErrPathNotValidated
}
pm.pathToSwitchTo = p
pm.activePath = id
return nil
}
func (pm *pathManagerOutgoing) NewPath(t *Transport, initialRTT time.Duration, enablePath func()) *Path {
pm.mx.Lock()
defer pm.mx.Unlock()
id := pm.nextPathID
pm.nextPathID++
return &Path{
pathManager: pm,
id: id,
tr: t,
enablePath: enablePath,
initialRTT: initialRTT,
abandon: make(chan struct{}),
}
}
func (pm *pathManagerOutgoing) NextPathToProbe() (_ protocol.ConnectionID, _ ackhandler.Frame, _ *Transport, hasPath bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
var p *pathOutgoing
id := invalidPathID
for _, pID := range pm.pathsToProbe {
var ok bool
p, ok = pm.paths[pID]
if ok {
id = pID
break
}
// if the path doesn't exist in the map, it might have been abandoned
pm.pathsToProbe = pm.pathsToProbe[1:]
}
if id == invalidPathID {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
connID, ok := pm.getConnID(id)
if !ok {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
var b [8]byte
_, _ = rand.Read(b[:])
p.pathChallenges = append(p.pathChallenges, b)
pm.pathsToProbe = pm.pathsToProbe[1:]
p.enablePath()
select {
case p.probeSent <- struct{}{}:
default:
}
frame := ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: b},
Handler: (*pathManagerOutgoingAckHandler)(pm),
}
return connID, frame, p.tr, true
}
func (pm *pathManagerOutgoing) HandlePathResponseFrame(f *wire.PathResponseFrame) {
pm.mx.Lock()
defer pm.mx.Unlock()
for _, p := range pm.paths {
if slices.Contains(p.pathChallenges, f.Data) {
// path validated
if !p.isValidated {
// make sure that duplicate PATH_RESPONSE frames are ignored
p.isValidated = true
p.pathChallenges = nil
close(p.validated)
}
break
}
}
}
func (pm *pathManagerOutgoing) ShouldSwitchPath() (*Transport, bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
if pm.pathToSwitchTo == nil {
return nil, false
}
p := pm.pathToSwitchTo
pm.pathToSwitchTo = nil
return p.tr, true
}
type pathManagerOutgoingAckHandler pathManagerOutgoing
var _ ackhandler.FrameHandler = &pathManagerOutgoingAckHandler{}
// OnAcked is called when the PATH_CHALLENGE is acked.
// This doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerOutgoingAckHandler) OnAcked(wire.Frame) {}
func (pm *pathManagerOutgoingAckHandler) OnLost(wire.Frame) {}

View File

@@ -31,7 +31,7 @@ func NewReader(r io.Reader) Reader {
func (r *byteReader) ReadByte() (byte, error) {
var b [1]byte
n, err := r.Reader.Read(b[:])
n, err := r.Read(b[:])
if n == 1 && err == io.EOF {
err = nil
}
@@ -63,6 +63,6 @@ func NewWriter(w io.Writer) Writer {
}
func (w *byteWriter) WriteByte(c byte) error {
_, err := w.Writer.Write([]byte{c})
_, err := w.Write([]byte{c})
return err
}

View File

@@ -125,17 +125,18 @@ func AppendWithLen(b []byte, i uint64, length int) []byte {
if l > length {
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
}
if length == 2 {
switch length {
case 2:
b = append(b, 0b01000000)
} else if length == 4 {
case 4:
b = append(b, 0b10000000)
} else if length == 8 {
case 8:
b = append(b, 0b11000000)
}
for j := 1; j < length-l; j++ {
for range length - l - 1 {
b = append(b, 0)
}
for j := 0; j < l; j++ {
for j := range l {
b = append(b, uint8(i>>(8*(l-1-j))))
}
return b

View File

@@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
@@ -16,10 +17,9 @@ import (
type receiveStreamI interface {
ReceiveStream
handleStreamFrame(*wire.StreamFrame) error
handleResetStreamFrame(*wire.ResetStreamFrame) error
handleStreamFrame(*wire.StreamFrame, time.Time) error
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
closeForShutdown(error)
getWindowUpdate() protocol.ByteCount
}
type receiveStream struct {
@@ -37,6 +37,9 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream
queuedStopSending bool
queuedMaxStreamData bool
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
@@ -54,8 +57,9 @@ type receiveStream struct {
}
var (
_ ReceiveStream = &receiveStream{}
_ receiveStreamI = &receiveStream{}
_ ReceiveStream = &receiveStream{}
_ receiveStreamI = &receiveStream{}
_ streamControlFrameGetter = &receiveStream{}
)
func newReceiveStream(
@@ -87,13 +91,19 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()
s.mutex.Lock()
n, err := s.readImpl(p)
queuedStreamWindowUpdate, queuedConnWindowUpdate, n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if queuedStreamWindowUpdate {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if queuedConnWindowUpdate {
s.sender.onHasConnectionData()
}
return n, err
}
@@ -118,17 +128,17 @@ func (s *receiveStream) isNewlyCompleted() bool {
return false
}
func (s *receiveStream) readImpl(p []byte) (int, error) {
func (s *receiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnWindowUpdate bool, _ int, _ error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
return false, false, 0, io.EOF
}
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
return false, false, 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
return false, false, 0, s.closeForShutdownErr
}
var bytesRead int
@@ -138,23 +148,23 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return bytesRead, s.closeForShutdownErr
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return bytesRead, s.closeForShutdownErr
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
}
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
return hasStreamWindowUpdate, hasConnWindowUpdate, 0, s.cancelErr
}
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return bytesRead, errDeadline
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
@@ -184,10 +194,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
}
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}
m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
@@ -197,7 +207,14 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m))
if hasStream {
s.queuedMaxStreamData = true
hasStreamWindowUpdate = true
}
if hasConn {
hasConnWindowUpdate = true
}
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
@@ -206,10 +223,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
s.currentFrameDone()
}
s.errorRead = true
return bytesRead, io.EOF
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF
}
}
return bytesRead, nil
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil
}
func (s *receiveStream) dequeueNextFrame() {
@@ -225,35 +242,39 @@ func (s *receiveStream) dequeueNextFrame() {
func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
s.cancelReadImpl(errorCode)
queuedNewControlFrame := s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if queuedNewControlFrame {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNewControlFrame bool) {
if s.cancelledLocally { // duplicate call to CancelRead
return
return false
}
if s.closeForShutdownErr != nil {
return false
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
return false
}
s.queuedStopSending = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
return true
}
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame, now time.Time) error {
s.mutex.Lock()
err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame, now)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
@@ -264,9 +285,9 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame, now time.Time) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin, now); err != nil {
return err
}
if frame.Fin {
@@ -282,9 +303,9 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
return nil
}
func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame, now time.Time) error {
s.mutex.Lock()
err := s.handleResetStreamFrameImpl(frame)
err := s.handleResetStreamFrameImpl(frame, now)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
@@ -294,11 +315,11 @@ func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) err
return err
}
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, now time.Time) error {
if s.closeForShutdownErr != nil {
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true, now); err != nil {
return err
}
s.finalOffset = frame.FinalSize
@@ -318,6 +339,29 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
return nil
}
func (s *receiveStream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.queuedStopSending && !s.queuedMaxStreamData {
return ackhandler.Frame{}, false, false
}
if s.queuedStopSending {
s.queuedStopSending = false
return ackhandler.Frame{
Frame: &wire.StopSendingFrame{StreamID: s.streamID, ErrorCode: s.cancelErr.ErrorCode},
}, true, s.queuedMaxStreamData
}
s.queuedMaxStreamData = false
return ackhandler.Frame{
Frame: &wire.MaxStreamDataFrame{
StreamID: s.streamID,
MaximumStreamData: s.flowController.GetWindowUpdate(now),
},
}, true, false
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
s.deadline = t
@@ -336,10 +380,6 @@ func (s *receiveStream) closeForShutdown(err error) {
s.signalRead()
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// signalRead performs a non-blocking send on the readChan
func (s *receiveStream) signalRead() {
select {

View File

@@ -9,126 +9,106 @@ import (
"github.com/quic-go/quic-go/internal/wire"
)
type framesToRetransmit struct {
crypto []*wire.CryptoFrame
other []wire.Frame
}
type retransmissionQueue struct {
initial []wire.Frame
initialCryptoData []*wire.CryptoFrame
handshake []wire.Frame
handshakeCryptoData []*wire.CryptoFrame
appData []wire.Frame
initial *framesToRetransmit
handshake *framesToRetransmit
appData framesToRetransmit
}
func newRetransmissionQueue() *retransmissionQueue {
return &retransmissionQueue{}
}
// AddPing queues a ping.
// It is used when a probe packet needs to be sent
func (q *retransmissionQueue) AddPing(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // Cannot send probe packets for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
q.addInitial(&wire.PingFrame{})
case protocol.EncryptionHandshake:
q.addHandshake(&wire.PingFrame{})
case protocol.Encryption1RTT:
q.addAppData(&wire.PingFrame{})
default:
panic("unexpected encryption level")
return &retransmissionQueue{
initial: &framesToRetransmit{},
handshake: &framesToRetransmit{},
}
}
func (q *retransmissionQueue) addInitial(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok {
q.initialCryptoData = append(q.initialCryptoData, cf)
if q.initial == nil {
return
}
q.initial = append(q.initial, f)
if cf, ok := f.(*wire.CryptoFrame); ok {
q.initial.crypto = append(q.initial.crypto, cf)
return
}
q.initial.other = append(q.initial.other, f)
}
func (q *retransmissionQueue) addHandshake(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshakeCryptoData = append(q.handshakeCryptoData, cf)
if q.handshake == nil {
return
}
q.handshake = append(q.handshake, f)
}
func (q *retransmissionQueue) HasInitialData() bool {
return len(q.initialCryptoData) > 0 || len(q.initial) > 0
}
func (q *retransmissionQueue) HasHandshakeData() bool {
return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0
}
func (q *retransmissionQueue) HasAppData() bool {
return len(q.appData) > 0
if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshake.crypto = append(q.handshake.crypto, cf)
return
}
q.handshake.other = append(q.handshake.other, f)
}
func (q *retransmissionQueue) addAppData(f wire.Frame) {
if _, ok := f.(*wire.StreamFrame); ok {
switch f := f.(type) {
case *wire.StreamFrame:
panic("STREAM frames are handled with their respective streams.")
case *wire.CryptoFrame:
q.appData.crypto = append(q.appData.crypto, f)
default:
q.appData.other = append(q.appData.other, f)
}
q.appData = append(q.appData, f)
}
func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.initialCryptoData) > 0 {
f := q.initialCryptoData[0]
func (q *retransmissionQueue) HasData(encLevel protocol.EncryptionLevel) bool {
//nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
return q.initial != nil &&
(len(q.initial.crypto) > 0 || len(q.initial.other) > 0)
case protocol.EncryptionHandshake:
return q.handshake != nil &&
(len(q.handshake.crypto) > 0 || len(q.handshake.other) > 0)
case protocol.Encryption1RTT:
return len(q.appData.crypto) > 0 || len(q.appData.other) > 0
}
return false
}
func (q *retransmissionQueue) GetFrame(encLevel protocol.EncryptionLevel, maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
var r *framesToRetransmit
//nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
r = q.initial
case protocol.EncryptionHandshake:
r = q.handshake
case protocol.Encryption1RTT:
r = &q.appData
}
if r == nil {
return nil
}
if len(r.crypto) > 0 {
f := r.crypto[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
if newFrame == nil && !needsSplit { // the whole frame fits
q.initialCryptoData = q.initialCryptoData[1:]
r.crypto = r.crypto[1:]
return f
}
if newFrame != nil { // frame was split. Leave the original frame in the queue.
return newFrame
}
}
if len(q.initial) == 0 {
if len(r.other) == 0 {
return nil
}
f := q.initial[0]
f := r.other[0]
if f.Length(v) > maxLen {
return nil
}
q.initial = q.initial[1:]
return f
}
func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.handshakeCryptoData) > 0 {
f := q.handshakeCryptoData[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
if newFrame == nil && !needsSplit { // the whole frame fits
q.handshakeCryptoData = q.handshakeCryptoData[1:]
return f
}
if newFrame != nil { // frame was split. Leave the original frame in the queue.
return newFrame
}
}
if len(q.handshake) == 0 {
return nil
}
f := q.handshake[0]
if f.Length(v) > maxLen {
return nil
}
q.handshake = q.handshake[1:]
return f
}
func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
if len(q.appData) == 0 {
return nil
}
f := q.appData[0]
if f.Length(v) > maxLen {
return nil
}
q.appData = q.appData[1:]
r.other = r.other[1:]
return f
}
@@ -137,25 +117,23 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) {
switch encLevel {
case protocol.EncryptionInitial:
q.initial = nil
q.initialCryptoData = nil
case protocol.EncryptionHandshake:
q.handshake = nil
q.handshakeCryptoData = nil
default:
panic(fmt.Sprintf("unexpected encryption level: %s", encLevel))
}
}
func (q *retransmissionQueue) InitialAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueInitialAckHandler)(q)
}
func (q *retransmissionQueue) HandshakeAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueHandshakeAckHandler)(q)
}
func (q *retransmissionQueue) AppDataAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueAppDataAckHandler)(q)
func (q *retransmissionQueue) AckHandler(encLevel protocol.EncryptionLevel) ackhandler.FrameHandler {
switch encLevel {
case protocol.EncryptionInitial:
return (*retransmissionQueueInitialAckHandler)(q)
case protocol.EncryptionHandshake:
return (*retransmissionQueueHandshakeAckHandler)(q)
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return (*retransmissionQueueAppDataAckHandler)(q)
}
return nil
}
type retransmissionQueueInitialAckHandler retransmissionQueue

View File

@@ -2,6 +2,7 @@ package quic
import (
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
@@ -10,22 +11,29 @@ import (
// A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface {
Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
WriteTo([]byte, net.Addr) error
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
ChangeRemoteAddr(addr net.Addr, info packetInfo)
capabilities() connCapabilities
}
type remoteAddrInfo struct {
addr net.Addr
oob []byte
}
type sconn struct {
rawConn
localAddr net.Addr
remoteAddr net.Addr
localAddr net.Addr
remoteAddrInfo atomic.Pointer[remoteAddrInfo]
logger utils.Logger
packetInfoOOB []byte
// If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled.
gotGSOError bool
// Used to catch the error sometimes returned by the first sendmsg call on Linux,
@@ -49,22 +57,26 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
// increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob)
oob = append(oob, make([]byte, 64)...)[:l]
return &sconn{
rawConn: c,
localAddr: localAddr,
remoteAddr: remote,
packetInfoOOB: oob,
logger: logger,
sc := &sconn{
rawConn: c,
localAddr: localAddr,
logger: logger,
}
sc.remoteAddrInfo.Store(&remoteAddrInfo{
addr: remote,
oob: oob,
})
return sc
}
func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn)
ai := c.remoteAddrInfo.Load()
err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn)
if err != nil && isGSOError(err) {
// disable GSO for future calls
c.gotGSOError = true
if c.logger.Debug() {
c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr)
c.logger.Debugf("GSO failed when sending to %s", ai.addr)
}
// send out the packets one by one
for len(p) > 0 {
@@ -72,7 +84,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
if l > int(gsoSize) {
l = int(gsoSize)
}
if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil {
if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil {
return err
}
p = p[l:]
@@ -91,6 +103,11 @@ func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16,
return err
}
func (c *sconn) WriteTo(b []byte, addr net.Addr) error {
_, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
return err
}
func (c *sconn) capabilities() connCapabilities {
capabilities := c.rawConn.capabilities()
if capabilities.GSO {
@@ -99,5 +116,12 @@ func (c *sconn) capabilities() connCapabilities {
return capabilities
}
func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) {
c.remoteAddrInfo.Store(&remoteAddrInfo{
addr: addr,
oob: info.OOB(),
})
}
func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddrInfo.Load().addr }
func (c *sconn) LocalAddr() net.Addr { return c.localAddr }

View File

@@ -1,9 +1,14 @@
package quic
import "github.com/quic-go/quic-go/internal/protocol"
import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
type sender interface {
Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
SendProbe(*packetBuffer, net.Addr)
Run() error
WouldBlock() bool
Available() <-chan struct{}
@@ -57,6 +62,10 @@ func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
}
}
func (h *sendQueue) SendProbe(p *packetBuffer, addr net.Addr) {
h.conn.WriteTo(p.Data, addr)
}
func (h *sendQueue) WouldBlock() bool {
return len(h.queue) == sendQueueCapacity
}

View File

@@ -18,7 +18,7 @@ type sendStreamI interface {
SendStream
handleStopSendingFrame(*wire.StopSendingFrame)
hasData() bool
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool)
popStreamFrame(protocol.ByteCount, protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool)
closeForShutdown(error)
updateSendWindow(protocol.ByteCount)
}
@@ -26,7 +26,7 @@ type sendStreamI interface {
type sendStream struct {
mutex sync.Mutex
numOutstandingFrames int64
numOutstandingFrames int64 // outstanding STREAM and RESET_STREAM frames
retransmissionQueue []*wire.StreamFrame
ctx context.Context
@@ -37,8 +37,10 @@ type sendStream struct {
writeOffset protocol.ByteCount
cancelWriteErr error
closeForShutdownErr error
// finalError is the error that is returned by Write.
// It can either be a cancellation error or the shutdown error.
finalError error
queuedResetStreamFrame *wire.ResetStreamFrame
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
@@ -46,6 +48,8 @@ type sendStream struct {
// This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations).
cancellationFlagged bool
cancelled bool // both local and remote cancellations
closedForShutdown bool // set by closeForShutdown
completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
@@ -59,8 +63,9 @@ type sendStream struct {
}
var (
_ SendStream = &sendStream{}
_ sendStreamI = &sendStream{}
_ SendStream = &sendStream{}
_ sendStreamI = &sendStream{}
_ streamControlFrameGetter = &sendStream{}
)
func newSendStream(
@@ -102,16 +107,15 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finalError != nil {
if s.cancelled {
s.cancellationFlagged = true
}
return s.isNewlyCompleted(), 0, s.finalError
}
if s.finishedWriting {
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.cancelWriteErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return false, 0, errDeadline
}
@@ -165,14 +169,14 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
}
deadlineTimer.Reset(deadline)
}
if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
if s.dataForWriting == nil || s.finalError != nil {
break
}
}
s.mutex.Unlock()
if !notifiedSender {
s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex
notifiedSender = true
}
if copied {
@@ -194,11 +198,11 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
if bytesWritten == len(p) {
return false, bytesWritten, nil
}
if s.closeForShutdownErr != nil {
return false, bytesWritten, s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
if s.finalError != nil {
if s.cancelled {
s.cancellationFlagged = true
}
return s.isNewlyCompleted(), bytesWritten, s.finalError
}
return false, bytesWritten, nil
}
@@ -213,37 +217,37 @@ func (s *sendStream) canBufferStreamFrame() bool {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
f, blocked, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil {
s.numOutstandingFrames++
}
s.mutex.Unlock()
if f == nil {
return ackhandler.StreamFrame{}, false, hasMoreData
return ackhandler.StreamFrame{}, blocked, hasMoreData
}
return ackhandler.StreamFrame{
Frame: f,
Handler: (*sendStreamAckHandler)(s),
}, true, hasMoreData
}, blocked, hasMoreData
}
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
return nil, false
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
if s.finalError != nil {
return nil, nil, false
}
if len(s.retransmissionQueue) > 0 {
f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
if f != nil || hasMoreRetransmissions {
if f == nil {
return nil, true
return nil, nil, true
}
// We always claim that we have more data to send.
// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
return f, true
return f, nil, true
}
}
@@ -255,41 +259,45 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
Offset: s.writeOffset,
DataLenPresent: true,
Fin: true,
}, false
}, nil, false
}
return nil, false
return nil, nil, false
}
sendWindow := s.flowController.SendWindowSize()
if sendWindow == 0 {
if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
StreamID: s.streamID,
MaximumStreamData: offset,
})
return nil, false
}
return nil, true
return nil, nil, true
}
f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
if dataLen := f.DataLen(); dataLen > 0 {
if f == nil {
return nil, nil, hasMoreData
}
if f.DataLen() > 0 {
s.writeOffset += f.DataLen()
s.flowController.AddBytesSent(f.DataLen())
}
var blocked *wire.StreamDataBlockedFrame
// If the entire send window is used, the stream might have become blocked on stream-level flow control.
// This is not guaranteed though, because the stream might also have been blocked on connection-level flow control.
if f.DataLen() == sendWindow && s.flowController.IsNewlyBlocked() {
blocked = &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset}
}
f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
if f.Fin {
s.finSent = true
}
return f, hasMoreData
return f, blocked, hasMoreData
}
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
if s.nextFrame != nil {
maxDataLen := min(sendWindow, s.nextFrame.MaxDataLen(maxBytes, v))
if maxDataLen == 0 {
return nil, true
}
nextFrame := s.nextFrame
s.nextFrame = nil
maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
if nextFrame.DataLen() > maxDataLen {
s.nextFrame = wire.GetStreamFrame()
s.nextFrame.StreamID = s.streamID
@@ -367,7 +375,7 @@ func (s *sendStream) isNewlyCompleted() bool {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 {
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil {
return false
}
// The stream is completed if we sent the FIN.
@@ -379,8 +387,8 @@ func (s *sendStream) isNewlyCompleted() bool {
// 1. the application called CancelWrite, or
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called CLsoe
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
// * the application called Close
if s.cancelled && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
}
@@ -389,13 +397,13 @@ func (s *sendStream) isNewlyCompleted() bool {
func (s *sendStream) Close() error {
s.mutex.Lock()
if s.closeForShutdownErr != nil {
if s.closedForShutdown || s.finishedWriting {
s.mutex.Unlock()
return nil
}
s.finishedWriting = true
cancelWriteErr := s.cancelWriteErr
if cancelWriteErr != nil {
cancelled := s.cancelled
if cancelled {
s.cancellationFlagged = true
}
completed := s.isNewlyCompleted()
@@ -404,44 +412,60 @@ func (s *sendStream) Close() error {
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if cancelWriteErr != nil {
if cancelled {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
s.ctxCancel(nil)
return nil
}
func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, false)
s.cancelWrite(errorCode, false)
}
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
// cancelWrite cancels the stream
// It is possible to cancel a stream after it has been closed, both locally and remotely.
// This is useful to prevent the retransmission of outstanding stream data.
func (s *sendStream) cancelWrite(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if !remote {
s.cancellationFlagged = true
}
if s.cancelWriteErr != nil {
if s.closedForShutdown {
s.mutex.Unlock()
return
}
s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.cancelWriteErr)
if !remote {
s.cancellationFlagged = true
if s.cancelled {
completed := s.isNewlyCompleted()
s.mutex.Unlock()
// The user has called CancelWrite. If the previous cancellation was
// because of a STOP_SENDING, we don't need to flag the error to the
// user anymore.
if completed {
s.sender.onStreamCompleted(s.streamID)
}
return
}
}
if s.cancelled {
s.mutex.Unlock()
return
}
s.cancelled = true
s.finalError = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote}
s.ctxCancel(s.finalError)
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
newlyCompleted := s.isNewlyCompleted()
s.mutex.Unlock()
s.signalWrite()
s.sender.queueControlFrame(&wire.ResetStreamFrame{
s.queuedResetStreamFrame = &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: errorCode,
})
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
s.mutex.Unlock()
s.signalWrite()
s.sender.onHasStreamControlFrame(s.streamID, s)
}
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
@@ -453,12 +477,28 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock()
if hasStreamData {
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, s)
}
}
func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.cancelWriteImpl(frame.ErrorCode, true)
s.cancelWrite(frame.ErrorCode, true)
}
func (s *sendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.queuedResetStreamFrame == nil {
return ackhandler.Frame{}, false, false
}
s.numOutstandingFrames++
f := ackhandler.Frame{
Frame: s.queuedResetStreamFrame,
Handler: (*sendStreamResetStreamHandler)(s),
}
s.queuedResetStreamFrame = nil
return f, true, false
}
func (s *sendStream) Context() context.Context {
@@ -478,7 +518,10 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closeForShutdownErr = err
s.closedForShutdown = true
if s.finalError == nil && !s.finishedWriting {
s.finalError = err
}
s.mutex.Unlock()
s.signalWrite()
}
@@ -499,7 +542,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
if s.cancelled {
s.mutex.Unlock()
return
}
@@ -507,10 +550,10 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := (*sendStream)(s).isNewlyCompleted()
completed := (*sendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
if completed {
s.sender.onStreamCompleted(s.streamID)
}
}
@@ -518,7 +561,7 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame)
s.mutex.Lock()
if s.cancelWriteErr != nil {
if s.cancelled {
s.mutex.Unlock()
return
}
@@ -530,5 +573,31 @@ func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, (*sendStream)(s))
}
type sendStreamResetStreamHandler sendStream
var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{}
func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
s.mutex.Lock()
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
completed := (*sendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
s.mutex.Lock()
s.queuedResetStreamFrame = f.(*wire.ResetStreamFrame)
s.numOutstandingFrames--
s.mutex.Unlock()
s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))
}

View File

@@ -17,8 +17,13 @@ import (
"github.com/quic-go/quic-go/logging"
)
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: server closed")
// ErrServerClosed is returned by the [Listener] or [EarlyListener]'s Accept method after a call to Close.
var ErrServerClosed = errServerClosed{}
type errServerClosed struct{}
func (errServerClosed) Error() string { return "quic: server closed" }
func (errServerClosed) Unwrap() error { return net.ErrClosed }
// packetHandler handles packets
type packetHandler interface {
@@ -56,6 +61,7 @@ type rejectedPacket struct {
// A Listener of QUIC
type baseServer struct {
tr *Transport
disableVersionNegotiation bool
acceptEarlyConns bool
@@ -67,9 +73,9 @@ type baseServer struct {
tokenGenerator *handshake.TokenGenerator
maxTokenAge time.Duration
connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager
onClose func()
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
onClose func()
receivedPackets chan receivedPacket
@@ -83,14 +89,14 @@ type baseServer struct {
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
*Transport,
protocol.ConnectionID, /* original dest connection ID */
*protocol.ConnectionID, /* retry src connection ID */
protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
protocol.StatelessResetToken,
*statelessResetter,
*Config,
*tls.Config,
*handshake.TokenGenerator,
@@ -100,15 +106,24 @@ type baseServer struct {
protocol.Version,
) quicConn
closeMx sync.Mutex
errorChan chan struct{} // is closed when the server is closed
closeErr error
running chan struct{} // closed as soon as run() returns
closeMx sync.Mutex
// errorChan is closed when Close is called. This has two effects:
// 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors
// 2. it stops handling of packets passed to this server
errorChan chan struct{}
// acceptChan is closed when Close returns.
// This only happens once all handshake in flight have either completed and canceled.
// Calls to Accept will first drain the queue of connections that have completed the handshake,
// and then return ErrServerClosed.
stopAccepting chan struct{}
closeErr error
running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan rejectedPacket
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
handshakingCount sync.WaitGroup
verifySourceAddress func(net.Addr) bool
@@ -131,11 +146,11 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) {
}
// Close closes the listener.
// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
// Accept will return [ErrServerClosed] as soon as all connections in the accept queue have been accepted.
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
// The effect of closing the listener depends on how it was created:
// * if it was created using Transport.Listen, already established connections will be unaffected
// * if it was created using the Listen convenience method, all established connection will be closed immediately
// - if it was created using [Transport.Listen], already established connections will be unaffected
// - if it was created using the [Listen] convenience method, all established connection will be closed immediately
func (l *Listener) Close() error {
return l.baseServer.Close()
}
@@ -171,7 +186,7 @@ func (l *EarlyListener) Addr() net.Addr {
}
// ListenAddr creates a QUIC server listening on a given address.
// See Listen for more details.
// See [Listen] for more details.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
conn, err := listenUDP(addr)
if err != nil {
@@ -184,7 +199,7 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, er
}).Listen(tlsConf, config)
}
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
// ListenAddrEarly works like [ListenAddr], but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
conn, err := listenUDP(addr)
if err != nil {
@@ -206,16 +221,16 @@ func listenUDP(addr string) (*net.UDPConn, error) {
}
// Listen listens for QUIC connections on a given net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// A single net.PacketConn can only be used for a single call to Listen.
//
// The tls.Config must not be nil and must contain a certificate configuration.
// Furthermore, it must define an application control (using NextProtos).
// Furthermore, it must define an application control (using [NextProtos]).
// The quic.Config may be nil, in that case the default values will be used.
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// This is a convenience function. More advanced use cases should instantiate a [Transport],
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for outgoing QUIC connections.
// When closing a listener created with Listen, all established QUIC connections will be closed immediately.
@@ -224,7 +239,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener
return tr.Listen(tlsConf, config)
}
// ListenEarly works like Listen, but it returns connections before the handshake completes.
// ListenEarly works like [Listen], but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.ListenEarly(tlsConf, config)
@@ -232,8 +247,9 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear
func newServer(
conn rawConn,
connHandler packetHandlerManager,
tr *Transport,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
@@ -248,15 +264,17 @@ func newServer(
s := &baseServer{
conn: conn,
connContext: connContext,
tr: tr,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
statelessResetter: statelessResetter,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
stopAccepting: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
@@ -327,7 +345,13 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
return nil, ctx.Err()
case conn := <-s.connQueue:
return conn, nil
case <-s.errorChan:
case <-s.stopAccepting:
// first drain the queue
select {
case conn := <-s.connQueue:
return conn, nil
default:
}
return nil, s.closeErr
}
}
@@ -351,6 +375,9 @@ func (s *baseServer) close(e error, notifyOnClose bool) {
if notifyOnClose {
s.onClose()
}
// wait until all handshakes in flight have terminated
s.handshakingCount.Wait()
close(s.stopAccepting)
}
// Addr returns the server's network address
@@ -361,6 +388,8 @@ func (s *baseServer) Addr() net.Addr {
func (s *baseServer) handlePacket(p receivedPacket) {
select {
case s.receivedPackets <- p:
case <-s.errorChan:
return
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.tracer != nil && s.tracer.DroppedPacket != nil {
@@ -397,6 +426,9 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, v) {
if s.disableVersionNegotiation {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedVersion)
}
return false
}
@@ -469,7 +501,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
}
// check again if we might have a connection now
if handler, ok := s.connHandler.Get(connID); ok {
if handler, ok := s.tr.connRunner().Get(connID); ok {
handler.handlePacket(p)
return true
}
@@ -559,7 +591,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
if handler, ok := s.tr.connRunner().Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
@@ -617,7 +649,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
conf, err := s.config.GetConfigForClient(&ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
})
@@ -674,14 +706,14 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
s.tr,
origDestConnID,
retrySrcConnID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
s.statelessResetter,
config,
s.tlsConf,
s.tokenGenerator,
@@ -695,7 +727,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision will occur if we receive the two Initial packets at the same time.
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
if added := s.tr.connRunner().AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(qerr.ConnectionRefused)
return nil
@@ -708,43 +740,42 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
go conn.run()
s.handshakingCount.Add(1)
go func() {
if completed := s.handleNewConn(conn); !completed {
return
}
select {
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
defer s.handshakingCount.Done()
s.handleNewConn(conn)
}()
go conn.run()
return nil
}
func (s *baseServer) handleNewConn(conn quicConn) bool {
func (s *baseServer) handleNewConn(conn quicConn) {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return false
return
case <-conn.Context().Done():
return false
return
case <-conn.earlyConnReady():
return true
}
} else {
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.Context().Done():
return
case <-conn.HandshakeComplete():
}
}
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
return false
case <-conn.Context().Done():
return false
case <-conn.HandshakeComplete():
return true
}
}
@@ -803,7 +834,7 @@ func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) {
hdr := p.hdr
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
extHdr, err := unpackLongHeader(opener, hdr, data)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil {

42
vendor/github.com/quic-go/quic-go/stateless_reset.go generated vendored Normal file
View File

@@ -0,0 +1,42 @@
package quic
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
type statelessResetter struct {
mx sync.Mutex
h hash.Hash
}
// newStatelessRetter creates a new stateless reset generator.
// It is valid to use a nil key. In that case, a random key will be used.
// This makes is impossible for on-path attackers to shut down established connections.
func newStatelessResetter(key *StatelessResetKey) *statelessResetter {
var h hash.Hash
if key != nil {
h = hmac.New(sha256.New, key[:])
} else {
b := make([]byte, 32)
_, _ = rand.Read(b)
h = hmac.New(sha256.New, b)
}
return &statelessResetter{h: h}
}
func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
r.mx.Lock()
defer r.mx.Unlock()
var token protocol.StatelessResetToken
r.h.Write(connID.Bytes())
copy(token[:], r.h.Sum(nil))
r.h.Reset()
return token
}

View File

@@ -24,8 +24,9 @@ var errDeadline net.Error = &deadlineError{}
// The streamSender is notified by the stream about various events.
type streamSender interface {
queueControlFrame(wire.Frame)
onHasStreamData(protocol.StreamID)
onHasConnectionData()
onHasStreamData(protocol.StreamID, sendStreamI)
onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter)
// must be called without holding the mutex that is acquired by closeForShutdown
onStreamCompleted(protocol.StreamID)
}
@@ -34,19 +35,16 @@ type streamSender interface {
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
onStreamCompletedImpl func()
onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter)
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.streamSender.onHasStreamData(id, str)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
func (s *uniStreamSender) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
s.onHasStreamControlFrameImpl(id, str)
}
var _ streamSender = &uniStreamSender{}
@@ -55,13 +53,12 @@ type streamI interface {
Stream
closeForShutdown(error)
// for receiving
handleStreamFrame(*wire.StreamFrame) error
handleResetStreamFrame(*wire.ResetStreamFrame) error
getWindowUpdate() protocol.ByteCount
handleStreamFrame(*wire.StreamFrame, time.Time) error
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
// for sending
hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool)
popStreamFrame(protocol.ByteCount, protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool)
updateSendWindow(protocol.ByteCount)
}
@@ -83,7 +80,10 @@ type stream struct {
sendStreamCompleted bool
}
var _ Stream = &stream{}
var (
_ Stream = &stream{}
_ streamControlFrameGetter = &receiveStream{}
)
// newStream creates a new Stream
func newStream(
@@ -101,6 +101,9 @@ func newStream(
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
senderForReceiveStream := &uniStreamSender{
@@ -111,6 +114,9 @@ func newStream(
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
@@ -126,6 +132,14 @@ func (s *stream) Close() error {
return s.sendStream.Close()
}
func (s *stream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
f, ok, _ := s.sendStream.getControlFrame(now)
if ok {
return f, true, true
}
return s.receiveStream.getControlFrame(now)
}
func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t) // SetReadDeadline never errors
_ = s.SetWriteDeadline(t) // SetWriteDeadline never errors

View File

@@ -2,9 +2,7 @@ package quic
import (
"context"
"errors"
"fmt"
"net"
"sync"
"github.com/quic-go/quic-go/internal/flowcontrol"
@@ -34,15 +32,12 @@ func convertStreamError(err error, stype protocol.StreamType, pers protocol.Pers
return fmt.Errorf(strError.Error(), ids...)
}
type streamOpenErr struct{ error }
// StreamLimitReachedError is returned from Connection.OpenStream and Connection.OpenUniStream
// when it is not possible to open a new stream because the number of opens streams reached
// the peer's stream limit.
type StreamLimitReachedError struct{}
var _ net.Error = &streamOpenErr{}
func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
func (streamOpenErr) Timeout() bool { return false }
// errTooManyOpenStreams is used internally by the outgoing streams maps.
var errTooManyOpenStreams = errors.New("too many open streams")
func (e StreamLimitReachedError) Error() string { return "too many open streams" }
type streamsMap struct {
ctx context.Context // not used for cancellations, but carries the values associated with the connection
@@ -52,6 +47,7 @@ type streamsMap struct {
maxIncomingUniStreams uint64
sender streamSender
queueControlFrame func(wire.Frame)
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
mutex sync.Mutex
@@ -67,14 +63,16 @@ var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
queueControlFrame func(wire.Frame),
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingBidiStreams uint64,
maxIncomingUniStreams uint64,
perspective protocol.Perspective,
) streamManager {
) *streamsMap {
m := &streamsMap{
ctx: ctx,
perspective: perspective,
queueControlFrame: queueControlFrame,
newFlowController: newFlowController,
maxIncomingBidiStreams: maxIncomingBidiStreams,
maxIncomingUniStreams: maxIncomingUniStreams,
@@ -91,7 +89,7 @@ func (m *streamsMap) initMaps() {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.incomingBidiStreams = newIncomingStreamsMap(
protocol.StreamTypeBidi,
@@ -100,7 +98,7 @@ func (m *streamsMap) initMaps() {
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.outgoingUniStreams = newOutgoingStreamsMap(
protocol.StreamTypeUni,
@@ -108,7 +106,7 @@ func (m *streamsMap) initMaps() {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.incomingUniStreams = newIncomingStreamsMap(
protocol.StreamTypeUni,
@@ -117,7 +115,7 @@ func (m *streamsMap) initMaps() {
return newReceiveStream(id, m.sender, m.newFlowController(id))
},
m.maxIncomingUniStreams,
m.sender.queueControlFrame,
m.queueControlFrame,
)
}

View File

@@ -2,6 +2,7 @@ package quic
import (
"context"
"slices"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
@@ -19,9 +20,7 @@ type outgoingStreamsMap[T outgoingStream] struct {
streamType protocol.StreamType
streams map[protocol.StreamNum]T
openQueue map[uint64]chan struct{}
lowestInQueue uint64
highestInQueue uint64
openQueue []chan struct{}
nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamNum // the maximum stream ID we're allowed to open
@@ -41,7 +40,6 @@ func newOutgoingStreamsMap[T outgoingStream](
return &outgoingStreamsMap[T]{
streamType: streamType,
streams: make(map[protocol.StreamNum]T),
openQueue: make(map[uint64]chan struct{}),
maxStream: protocol.InvalidStreamNum,
nextStream: 1,
newStream: newStream,
@@ -60,7 +58,7 @@ func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
// if there are OpenStreamSync calls waiting, return an error here
if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
m.maybeSendBlockedFrame()
return *new(T), streamOpenErr{errTooManyOpenStreams}
return *new(T), &StreamLimitReachedError{}
}
return m.openStream(), nil
}
@@ -72,22 +70,15 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
if m.closeErr != nil {
return *new(T), m.closeErr
}
if err := ctx.Err(); err != nil {
return *new(T), err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
queuePos := m.highestInQueue
m.highestInQueue++
if len(m.openQueue) == 0 {
m.lowestInQueue = queuePos
}
m.openQueue[queuePos] = waitChan
m.openQueue = append(m.openQueue, waitChan)
m.maybeSendBlockedFrame()
for {
@@ -95,12 +86,17 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
select {
case <-ctx.Done():
m.mutex.Lock()
delete(m.openQueue, queuePos)
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
return c == waitChan
})
// If we just received a MAX_STREAMS frame, this might have been the next stream
// that could be opened. Make sure we unblock the next OpenStreamSync call.
m.maybeUnblockOpenSync()
return *new(T), ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
m.mutex.Lock()
if m.closeErr != nil {
return *new(T), m.closeErr
}
@@ -109,9 +105,8 @@ func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
continue
}
str := m.openStream()
delete(m.openQueue, queuePos)
m.lowestInQueue = queuePos + 1
m.unblockOpenSync()
m.openQueue = m.openQueue[1:]
m.maybeUnblockOpenSync()
return str, nil
}
}
@@ -181,7 +176,7 @@ func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
m.maybeSendBlockedFrame()
}
m.unblockOpenSync()
m.maybeUnblockOpenSync()
}
// UpdateSendWindow is called when the peer's transport parameters are received.
@@ -196,27 +191,25 @@ func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
}
// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
func (m *outgoingStreamsMap[T]) unblockOpenSync() {
func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
c, ok := m.openQueue[qp]
if !ok { // entry was deleted because the context was canceled
continue
}
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case c <- struct{}{}:
default:
}
if m.nextStream > m.maxStream {
return
}
// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case m.openQueue[0] <- struct{}{}:
default:
}
}
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
@@ -226,5 +219,5 @@ func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
close(c)
}
}
m.mutex.Unlock()
m.openQueue = nil
}

View File

@@ -14,7 +14,7 @@ import (
)
// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header.
// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will use it.
// If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it.
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.
type OOBCapablePacketConn interface {
net.PacketConn
@@ -58,8 +58,8 @@ func wrapConn(pc net.PacketConn) (rawConn, error) {
return nil, err
}
// only set DF on UDP sockets
if _, ok := pc.LocalAddr().(*net.UDPAddr); ok {
// Only set DF on sockets that we expect to be able to handle that configuration.
var err error
supportsDF, err = setDF(rawConn)
if err != nil {
@@ -92,7 +92,7 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) {
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
n, addr, err := c.PacketConn.ReadFrom(buffer.Data)
n, addr, err := c.ReadFrom(buffer.Data)
if err != nil {
return receivedPacket{}, err
}
@@ -111,7 +111,7 @@ func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint1
if ecn != protocol.ECNUnsupported {
panic("cannot use ECN with a basicConn")
}
return c.PacketConn.WriteTo(b, addr)
return c.WriteTo(b, addr)
}
func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} }

View File

@@ -4,47 +4,67 @@ package quic
import (
"errors"
"fmt"
"strconv"
"strings"
"syscall"
"golang.org/x/sys/unix"
)
"github.com/quic-go/quic-go/internal/utils"
// for macOS versions, see https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards
const (
macOSVersion11 = 20
macOSVersion15 = 24
)
func setDF(rawConn syscall.RawConn) (bool, error) {
// Setting DF bit is only supported from macOS11
// Setting DF bit is only supported from macOS 11.
// https://github.com/chromium/chromium/blob/117.0.5881.2/net/socket/udp_socket_posix.cc#L555
if supportsDF, err := isAtLeastMacOS11(); !supportsDF || err != nil {
version, err := getMacOSVersion()
if err != nil || version < macOSVersion11 {
return false, err
}
// Enabling IP_DONTFRAG will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error
var controlErr error
var disableDF bool
if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
addr, err := unix.Getsockname(int(fd))
if err != nil {
controlErr = fmt.Errorf("getsockname: %w", err)
return
}
// Dual-stack sockets are effectively IPv6 sockets (with IPV6_ONLY set to 0).
// On macOS, the DF bit on dual-stack sockets is controlled by the IPV6_DONTFRAG option.
// See https://datatracker.ietf.org/doc/draft-seemann-tsvwg-udp-fragmentation/ for details.
switch addr.(type) {
case *unix.SockaddrInet4:
controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
case *unix.SockaddrInet6:
controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
// Setting the DF bit on dual-stack sockets works since macOS Sequoia.
// Disable DF on dual-stack sockets before Sequoia.
if version < macOSVersion15 {
// check if this is a dual-stack socket by reading the IPV6_V6ONLY flag
v6only, err := unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY)
if err != nil {
controlErr = fmt.Errorf("getting IPV6_V6ONLY: %w", err)
return
}
disableDF = v6only == 0
}
default:
controlErr = fmt.Errorf("unknown address type: %T", addr)
}
}); err != nil {
return false, err
}
switch {
case errDFIPv4 == nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.")
case errDFIPv4 == nil && errDFIPv6 != nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4.")
case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.")
// On macOS, the syscall for setting DF bit for IPv4 fails on dual-stack listeners.
// Treat the connection as not having DF enabled, even though the DF bit will be set
// when used for IPv6.
// See https://github.com/quic-go/quic-go/issues/3793 for details.
return false, nil
case errDFIPv4 != nil && errDFIPv6 != nil:
return false, errors.New("setting DF failed for both IPv4 and IPv6")
if controlErr != nil {
return false, controlErr
}
return true, nil
return !disableDF, nil
}
func isSendMsgSizeErr(err error) bool {
@@ -53,22 +73,20 @@ func isSendMsgSizeErr(err error) bool {
func isRecvMsgSizeErr(error) bool { return false }
func isAtLeastMacOS11() (bool, error) {
func getMacOSVersion() (int, error) {
uname := &unix.Utsname{}
err := unix.Uname(uname)
if err != nil {
return false, err
if err := unix.Uname(uname); err != nil {
return 0, err
}
release := string(uname.Release[:])
if idx := strings.Index(release, "."); idx != -1 {
version, err := strconv.Atoi(release[:idx])
if err != nil {
return false, err
}
// Darwin version 20 is macOS version 11
// https://en.wikipedia.org/wiki/Darwin_(operating_system)#Darwin_20_onwards
return version >= 20, nil
idx := strings.Index(release, ".")
if idx == -1 {
return 0, nil
}
return false, nil
version, err := strconv.Atoi(release[:idx])
if err != nil {
return 0, err
}
return version, nil
}

View File

@@ -16,8 +16,8 @@ func setDF(rawConn syscall.RawConn) (bool, error) {
// and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO)
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
}); err != nil {
return false, err
}

View File

@@ -12,21 +12,19 @@ import (
)
const (
// IP_DONTFRAGMENT controls the Don't Fragment (DF) bit.
//
// It's the same code point for both IPv4 and IPv6 on Windows.
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html
//
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAGMENT.html
//nolint:stylecheck
IP_DONTFRAGMENT = 14
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html
//nolint:stylecheck
IPV6_DONTFRAG = 14
)
func setDF(rawConn syscall.RawConn) (bool, error) {
var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1)
errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IP_DONTFRAGMENT, 1)
errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1)
}); err != nil {
return false, err
}

View File

@@ -83,7 +83,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
if err != nil {
return nil, err
}
needsPacketInfo := false
var needsPacketInfo bool
if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
needsPacketInfo = true
}
@@ -257,7 +257,7 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso
}
}
}
n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
n, _, err := c.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@@ -16,6 +17,31 @@ import (
"github.com/quic-go/quic-go/logging"
)
// ErrTransportClosed is returned by the [Transport]'s Listen or Dial method after it was closed.
var ErrTransportClosed = &errTransportClosed{}
type errTransportClosed struct {
err error
}
func (e *errTransportClosed) Unwrap() []error { return []error{net.ErrClosed, e.err} }
func (e *errTransportClosed) Error() string {
if e.err == nil {
return "quic: transport closed"
}
return fmt.Sprintf("quic: transport closed: %s", e.err)
}
func (e *errTransportClosed) Is(target error) bool {
_, ok := target.(*errTransportClosed)
return ok
}
type transportID uint64
var transportIDCounter atomic.Uint64
var errListenerAlreadySet = errors.New("listener already set")
// The Transport is the central point to manage incoming and outgoing QUIC connections.
@@ -23,7 +49,7 @@ var errListenerAlreadySet = errors.New("listener already set")
// This means that a single UDP socket can be used for listening for incoming connections, as well as
// for dialing an arbitrary number of outgoing connections.
// A Transport handles a single net.PacketConn, and offers a range of configuration options
// compared to the simple helper functions like Listen and Dial that this package provides.
// compared to the simple helper functions like [Listen] and [Dial] that this package provides.
type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
@@ -111,11 +137,13 @@ type Transport struct {
initErr error
// Set in init.
transportID transportID
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
server *baseServer
@@ -125,7 +153,7 @@ type Transport struct {
statelessResetQueue chan receivedPacket
listening chan struct{} // is closed when listen returns
closed bool
closeErr error
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
@@ -137,7 +165,7 @@ type Transport struct {
// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
// Listen may only be called again after the current listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
s, err := t.createServer(tlsConf, conf, false)
if err != nil {
@@ -148,7 +176,7 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
// ListenEarly may only be called again after the current listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
s, err := t.createServer(tlsConf, conf, true)
if err != nil {
@@ -168,6 +196,9 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closeErr != nil {
return nil, t.closeErr
}
if t.server != nil {
return nil, errListenerAlreadySet
}
@@ -175,17 +206,22 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
if err := t.init(false); err != nil {
return nil, err
}
maxTokenAge := t.MaxTokenAge
if maxTokenAge == 0 {
maxTokenAge = 24 * time.Hour
}
s := newServer(
t.conn,
t.handlerMap,
t,
t.connIDGenerator,
t.statelessResetter,
t.ConnContext,
tlsConf,
conf,
t.Tracer,
t.closeServer,
*t.TokenGeneratorKey,
t.MaxTokenAge,
maxTokenAge,
t.VerifySourceAddress,
t.DisableVersionNegotiationPackets,
allow0RTT,
@@ -205,24 +241,142 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
}
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
setTLSConfigServerName(tlsConf, addr, host)
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
return t.doDial(ctx,
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
tlsConf,
conf,
0,
false,
use0RTT,
conf.Versions[0],
)
}
func (t *Transport) doDial(
ctx context.Context,
sendConn sendConn,
tlsConf *tls.Config,
config *Config,
initialPacketNumber protocol.PacketNumber,
hasNegotiatedVersion bool,
use0RTT bool,
version protocol.Version,
) (quicConn, error) {
srcConnID, err := t.connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
tracingID := nextConnTracingID()
ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
t.mutex.Lock()
if t.closeErr != nil {
t.mutex.Unlock()
return nil, t.closeErr
}
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
}
if tracer != nil && tracer.StartedConnection != nil {
tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID)
}
logger := utils.DefaultLogger.WithPrefix("client")
logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version)
conn := newClientConnection(
context.WithoutCancel(ctx),
sendConn,
t,
destConnID,
srcConnID,
t.connIDGenerator,
t.statelessResetter,
config,
tlsConf,
initialPacketNumber,
use0RTT,
hasNegotiatedVersion,
tracer,
logger,
version,
)
t.handlerMap.Add(srcConnID, conn)
t.mutex.Unlock()
// The error channel needs to be buffered, as the run loop will continue running
// after doDial returns (if the handshake is successful).
errChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if t.isSingleUse {
t.Close()
}
errChan <- err
}()
// Only set when we're using 0-RTT.
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if use0RTT {
earlyConnChan = conn.earlyConnReady()
}
select {
case <-ctx.Done():
conn.destroy(nil)
// wait until the Go routine that called Connection.run() returns
select {
case <-errChan:
case <-recreateChan:
}
return nil, context.Cause(ctx)
case params := <-recreateChan:
return t.doDial(ctx,
sendConn,
tlsConf,
config,
params.nextPacketNumber,
true,
use0RTT,
params.nextVersion,
)
case err := <-errChan:
return nil, err
case <-earlyConnChan:
// ready to send 0-RTT data
return conn, nil
case <-conn.HandshakeComplete():
// handshake successfully completed
return conn, nil
}
}
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
t.transportID = transportID(transportIDCounter.Add(1))
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
@@ -237,7 +391,9 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
if t.handlerMap == nil { // allows mocking the handlerMap in tests
t.handlerMap = newPacketHandlerMap(t.enqueueClosePacket, t.logger)
}
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
@@ -262,14 +418,20 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
t.statelessResetter = newStatelessResetter(t.StatelessResetKey)
getMultiplexer().AddConn(t.Conn)
go t.listen(conn)
go t.runSendQueue()
})
return t.initErr
}
func (t *Transport) connRunner() packetHandlerManager {
return t.handlerMap
}
func (t *Transport) id() transportID { return t.transportID }
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
@@ -300,11 +462,14 @@ func (t *Transport) runSendQueue() {
}
}
// Close closes the underlying connection.
// Close stops listening for UDP datagrams on the Transport.Conn.
// If any listener was started, it will be closed as well.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
// avoid race condition if the transport is currently being initialized
t.init(false)
t.close(nil)
if t.createdConn {
if err := t.Conn.Close(); err != nil {
return err
@@ -323,7 +488,7 @@ func (t *Transport) closeServer() {
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
t.closed = true
t.closeErr = ErrServerClosed
}
t.mutex.Unlock()
if t.createdConn {
@@ -339,10 +504,12 @@ func (t *Transport) closeServer() {
func (t *Transport) close(e error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closed {
if t.closeErr != nil {
return
}
e = &errTransportClosed{err: e}
if t.handlerMap != nil {
t.handlerMap.Close(e)
}
@@ -352,7 +519,7 @@ func (t *Transport) close(e error) {
if t.Tracer != nil && t.Tracer.Close != nil {
t.Tracer.Close()
}
t.closed = true
t.closeErr = e
}
// only print warnings about the UDP receive buffer size once
@@ -360,7 +527,6 @@ var setBufferWarningOnce sync.Once
func (t *Transport) listen(conn rawConn) {
defer close(t.listening)
defer getMultiplexer().RemoveConn(t.Conn)
for {
p, err := conn.ReadPacket()
@@ -370,7 +536,7 @@ func (t *Transport) listen(conn rawConn) {
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
t.mutex.Lock()
closed := t.closed
closed := t.closeErr != nil
t.mutex.Unlock()
if closed {
return
@@ -424,7 +590,12 @@ func (t *Transport) handlePacket(p receivedPacket) {
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
t.maybeSendStatelessReset(p)
if statelessResetQueued := t.maybeSendStatelessReset(p); !statelessResetQueued {
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
}
p.buffer.Release()
}
return
}
@@ -432,29 +603,32 @@ func (t *Transport) handlePacket(p receivedPacket) {
defer t.mutex.Unlock()
if t.server == nil { // no server set
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
}
p.buffer.MaybeRelease()
return
}
t.server.handlePacket(p)
}
func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
func (t *Transport) maybeSendStatelessReset(p receivedPacket) (statelessResetQueued bool) {
if t.StatelessResetKey == nil {
p.buffer.Release()
return
return false
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
p.buffer.Release()
return
return false
}
select {
case t.statelessResetQueue <- p:
return true
default:
// it's fine to not send a stateless reset when we're busy
p.buffer.Release()
return false
}
}
@@ -466,7 +640,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
return
}
token := t.handlerMap.GetStatelessResetToken(connID)
token := t.statelessResetter.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
@@ -489,7 +663,7 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{Token: token})
go conn.destroy(&StatelessResetError{})
return true
}
return false

View File

@@ -1,71 +0,0 @@
package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type windowUpdateQueue struct {
mutex sync.Mutex
queue map[protocol.StreamID]struct{} // used as a set
queuedConn bool // connection-level window update
streamGetter streamGetter
connFlowController flowcontrol.ConnectionFlowController
callback func(wire.Frame)
}
func newWindowUpdateQueue(
streamGetter streamGetter,
connFC flowcontrol.ConnectionFlowController,
cb func(wire.Frame),
) *windowUpdateQueue {
return &windowUpdateQueue{
queue: make(map[protocol.StreamID]struct{}),
streamGetter: streamGetter,
connFlowController: connFC,
callback: cb,
}
}
func (q *windowUpdateQueue) AddStream(id protocol.StreamID) {
q.mutex.Lock()
q.queue[id] = struct{}{}
q.mutex.Unlock()
}
func (q *windowUpdateQueue) AddConnection() {
q.mutex.Lock()
q.queuedConn = true
q.mutex.Unlock()
}
func (q *windowUpdateQueue) QueueAll() {
q.mutex.Lock()
// queue a connection-level window update
if q.queuedConn {
q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()})
q.queuedConn = false
}
// queue all stream-level window updates
for id := range q.queue {
delete(q.queue, id)
str, err := q.streamGetter.GetOrOpenReceiveStream(id)
if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update
continue
}
offset := str.getWindowUpdate()
if offset == 0 { // can happen if we received a final offset, right after queueing the window update
continue
}
q.callback(&wire.MaxStreamDataFrame{
StreamID: id,
MaximumStreamData: offset,
})
}
q.mutex.Unlock()
}