1
0
mirror of https://github.com/fumiama/terasu-cloudflared.git synced 2026-06-05 00:50:24 +08:00
Files
terasu-cloudflared/supervisor/supervisor.go
Devin Carr 41dffd7f3c CUSTESC-53681: Correct QUIC connection management for datagram handlers
Corrects the pattern of using errgroups and context cancellation to simplify the logic for canceling extra routines for the QUIC connection. The extra context cancellation is redundant because the errgroup already cancels its own provided context when a routine returns (error or not).

For the datagram handler specifically, since it can respond faster to a context cancellation from the QUIC connection, we wrap the error before surfacing it outside of the QUIC connection scope to the supervisor. Additionally, the supervisor will look for this error type to check if it should retry the QUIC connection. These two operations are required because the supervisor does not look for a context canceled error when deciding to retry a connection. If a context canceled error from the datagram handler were to be returned up to the supervisor on the initial connection, the cloudflared application would exit. We want to ensure that cloudflared maintains connection attempts even if any of the services on top of a QUIC connection fail (the datagram handler in this case).

Additional logging is also introduced along these paths to help with understanding the error conditions from the specific handlers on top of a QUIC connection.

Related CUSTESC-53681

Closes TUN-9610
2025-08-19 16:10:00 -07:00

331 lines
10 KiB
Go

package supervisor
import (
"context"
"errors"
"net"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/orchestration"
v3 "github.com/cloudflare/cloudflared/quic/v3"
"github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelstate"
)
const (
// tunnelRetryDuration is the waiting time before retrying a failed tunnel connection.
// Run passes it to retry.NewBackoff when building the backoff used to restart tunnels.
tunnelRetryDuration = time.Second * 10
// registrationInterval is the interval between registering new tunnels: initialize
// sleeps for this long after starting each additional HA tunnel.
registrationInterval = time.Second
)
// Supervisor manages non-declarative tunnels. Establishes connections with the edge, and
// reconnects them if they disconnect.
// NOTE(review): the original comment said "TCP connections", but this file also handles
// QUIC error types and picks the protocol via config.ProtocolSelector — confirm wording.
type Supervisor struct {
// config is the tunnel configuration. initialize may lower config.HAConnections to the
// number of available edge addresses.
config *TunnelConfig
// orchestrator is stored as passed to NewSupervisor.
orchestrator *orchestration.Orchestrator
// edgeIPs holds the edge addresses obtained in NewSupervisor (static or resolved).
edgeIPs *edgediscovery.Edge
// edgeTunnelServer serves a single tunnel connection; startFirstTunnel and startTunnel
// call its Serve method.
edgeTunnelServer TunnelServer
// tunnelErrors receives one tunnelError from every tunnel goroutine when it exits.
// It is created unbuffered, so senders block until Run/initialize receive.
tunnelErrors chan tunnelError
// tunnelsConnecting maps a connection index to the signal channel of a tunnel that has
// been started; entries are added by newConnectedTunnelSignal and removed by
// waitForNextTunnel.
tunnelsConnecting map[int]chan struct{}
// tunnelsProtocolFallback maps a connection index to the protocol/backoff state that is
// handed to Serve for that connection.
tunnelsProtocolFallback map[int]*protocolFallback
// nextConnectedIndex and nextConnectedSignal are used to wait for all
// currently-connecting tunnels to finish connecting so we can reset backoff timer
nextConnectedIndex int
nextConnectedSignal chan struct{}
// log is the connection-aware logger built in NewSupervisor from config.Log.
log *ConnAwareLogger
// logTransport is copied from config.LogTransport.
logTransport *zerolog.Logger
// reconnectCh is stored as passed to NewSupervisor (also given to the EdgeTunnelServer).
reconnectCh chan ReconnectSignal
// gracefulShutdownC signals that a graceful shutdown has started: initialize returns
// errEarlyShutdown on it, and Run stops restarting failed tunnels once it fires.
gracefulShutdownC <-chan struct{}
}
// errEarlyShutdown is returned by initialize when gracefulShutdownC fires before the
// first tunnel has connected. Run treats it as a clean exit and returns nil.
var errEarlyShutdown = errors.New("shutdown started")
// tunnelError is the message a tunnel goroutine sends on Supervisor.tunnelErrors when
// its connection ends.
type tunnelError struct {
// index is the connection index of the tunnel that terminated.
index int
// err is the result returned by the tunnel server's Serve call; it may be nil.
err error
}
// NewSupervisor builds a Supervisor for the given tunnel configuration.
//
// Edge addresses are taken from config.EdgeAddrs when any are provided; otherwise
// they are resolved from config.Region and config.EdgeIPVersion. An error from
// either path is returned unchanged together with a nil Supervisor.
//
// The returned Supervisor owns a freshly created EdgeTunnelServer that shares the
// same config, orchestrator, edge addresses, connection tracker, logger,
// reconnectCh and gracefulShutdownC.
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
	var (
		edgeAddrs *edgediscovery.Edge
		err       error
	)
	if len(config.EdgeAddrs) == 0 {
		edgeAddrs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
	} else {
		// Static edge addresses were supplied by the user.
		edgeAddrs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
	}
	if err != nil {
		return nil, err
	}

	connTracker := tunnelstate.NewConnTracker(config.Log)
	connLogger := NewConnAwareLogger(config.Log, connTracker, config.Observer)
	addrFallback := NewIPAddrFallback(config.MaxEdgeAddrRetries)

	// Datagram (v3) metrics are created against the default Prometheus registerer
	// and shared with the session manager.
	metrics := v3.NewMetrics(prometheus.DefaultRegisterer)
	sessions := v3.NewSessionManager(metrics, config.Log, config.OriginDialerService, orchestrator.GetFlowLimiter())

	server := &EdgeTunnelServer{
		config:            config,
		orchestrator:      orchestrator,
		sessionManager:    sessions,
		datagramMetrics:   metrics,
		edgeAddrs:         edgeAddrs,
		edgeAddrHandler:   addrFallback,
		edgeBindAddr:      config.EdgeBindAddr,
		tracker:           connTracker,
		reconnectCh:       reconnectCh,
		gracefulShutdownC: gracefulShutdownC,
		connAwareLogger:   connLogger,
	}

	return &Supervisor{
		config:                  config,
		orchestrator:            orchestrator,
		edgeIPs:                 edgeAddrs,
		edgeTunnelServer:        server,
		tunnelErrors:            make(chan tunnelError),
		tunnelsConnecting:       map[int]chan struct{}{},
		tunnelsProtocolFallback: map[int]*protocolFallback{},
		log:                     connLogger,
		logTransport:            config.LogTransport,
		reconnectCh:             reconnectCh,
		gracefulShutdownC:       gracefulShutdownC,
	}, nil
}
// Run starts the supervisor and blocks until all work is done.
//
// It first launches the optional ICMP router and the origin DNS refresh loop in
// the background, then calls initialize to bring up the tunnels. If
// initialize reports errEarlyShutdown, Run returns nil; any other initialization
// error is logged and returned.
//
// Afterwards Run loops over the following events:
//   - ctx is cancelled: wait for every active tunnel to report on s.tunnelErrors,
//     then return nil.
//   - a tunnel terminated: if it failed and no graceful shutdown is in progress,
//     it is either restarted immediately (ReconnectSignal), dropped (no more
//     retries allowed by its protocol fallback), or queued for a restart once the
//     backoff timer fires. Otherwise, when no tunnels remain active, Run returns nil.
//   - the backoff timer fired: restart all queued tunnels.
//   - a connecting tunnel signalled that it connected: move on to the next
//     connecting tunnel, or put the backoff into its grace period if none is left.
//   - graceful shutdown started: stop restarting failed tunnels.
func (s *Supervisor) Run(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	if s.config.ICMPRouterServer != nil {
		go func() {
			if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
				// A closed listener is an expected way to stop, so log it at info level.
				if errors.Is(err, net.ErrClosed) {
					s.log.Logger().Info().Err(err).Msg("icmp router terminated")
				} else {
					s.log.Logger().Err(err).Msg("icmp router terminated")
				}
			}
		}()
	}

	// Setup DNS Resolver refresh
	go s.config.OriginDNSService.StartRefreshLoop(ctx)

	if err := s.initialize(ctx, connectedSignal); err != nil {
		// Use errors.Is rather than == so the sentinel is still recognised if it
		// is ever wrapped.
		if errors.Is(err, errEarlyShutdown) {
			return nil
		}
		s.log.Logger().Error().Err(err).Msg("initial tunnel connection failed")
		return err
	}

	// tunnelsWaiting holds the indexes of failed tunnels queued for restart when
	// backoffTimer fires; tunnelsActive counts goroutines that will still report
	// on s.tunnelErrors.
	var tunnelsWaiting []int
	tunnelsActive := s.config.HAConnections

	backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
	// A nil backoffTimer makes its select case block forever, i.e. "no timer set".
	var backoffTimer <-chan time.Time

	shuttingDown := false
	for {
		select {
		// Context cancelled
		case <-ctx.Done():
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// startTunnel completed with a response
		// (note that this may also be caused by context cancellation)
		case tunnelErr := <-s.tunnelErrors:
			tunnelsActive--
			s.log.ConnAwareLogger().Err(tunnelErr.err).Int(connection.LogFieldConnIndex, tunnelErr.index).Msg("Connection terminated")
			if tunnelErr.err != nil && !shuttingDown {
				// For tunnels that closed with reconnect signal, we reconnect immediately
				if _, isReconnect := tunnelErr.err.(ReconnectSignal); isReconnect {
					go s.startTunnel(ctx, tunnelErr.index, s.newConnectedTunnelSignal(tunnelErr.index))
					tunnelsActive++
					continue
				}
				// Make sure we don't continue if there is no more fallback allowed
				if _, canRetry := s.tunnelsProtocolFallback[tunnelErr.index].GetMaxBackoffDuration(ctx); !canRetry {
					continue
				}
				tunnelsWaiting = append(tunnelsWaiting, tunnelErr.index)
				s.waitForNextTunnel(tunnelErr.index)

				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}
			} else if tunnelsActive == 0 {
				s.log.ConnAwareLogger().Msg("no more connections active and exiting")
				// All connected tunnels exited gracefully, no more work to do
				return nil
			}
		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		case <-s.gracefulShutdownC:
			shuttingDown = true
		}
	}
}
// initialize brings up the first tunnel and, once it has connected, starts the
// remaining tunnels up to config.HAConnections, pausing registrationInterval
// after each one.
//
// config.HAConnections is capped to the number of available edge addresses
// before anything is started.
//
// It returns nil on success. Before the first tunnel connects it returns
// ctx.Err() if ctx is cancelled, the first tunnel's result if that tunnel
// terminates, or errEarlyShutdown if a graceful shutdown begins.
func (s *Supervisor) initialize(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	if maxConns := s.edgeIPs.AvailableAddrs(); s.config.HAConnections > maxConns {
		s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, maxConns)
		s.config.HAConnections = maxConns
	}

	firstFallback := &protocolFallback{
		retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
		s.config.ProtocolSelector.Current(),
		false,
	}
	s.tunnelsProtocolFallback[0] = firstFallback

	go s.startFirstTunnel(ctx, connectedSignal)

	// Block until the first tunnel either connects or we learn that it never will;
	// only the connected case falls through to start the other HA tunnels.
	select {
	case <-connectedSignal.Wait():
	case <-s.gracefulShutdownC:
		return errEarlyShutdown
	case firstResult := <-s.tunnelErrors:
		return firstResult.err
	case <-ctx.Done():
		// Let the first tunnel goroutine deliver its result so it does not stay blocked.
		<-s.tunnelErrors
		return ctx.Err()
	}

	// At least one successful connection, so start the rest
	for connIndex := 1; connIndex < s.config.HAConnections; connIndex++ {
		fallback := &protocolFallback{
			retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
			// Set the protocol we know the first tunnel connected with.
			s.tunnelsProtocolFallback[0].protocol,
			false,
		}
		s.tunnelsProtocolFallback[connIndex] = fallback
		go s.startTunnel(ctx, connIndex, s.newConnectedTunnelSignal(connIndex))
		time.Sleep(registrationInterval)
	}
	return nil
}
// startFirstTunnel serves the first tunnel connection (index 0) and keeps
// restarting it while the failure is considered retryable. Whatever the last
// Serve call returned is sent on s.tunnelErrors when the function exits.
// connectedSignal is handed to Serve so it can be signalled once registration
// succeeds.
//
// The loop stops when ctx is done, when Serve returns nil, when the protocol
// fallback allows no further retries, or when the error is not one of the
// retryable kinds listed below.
func (s *Supervisor) startFirstTunnel(
	ctx context.Context,
	connectedSignal *signal.Signal,
) {
	const firstConnIndex = 0

	// lastErr is reported on s.tunnelErrors whichever way the function exits.
	var lastErr error
	defer func() {
		s.tunnelErrors <- tunnelError{index: firstConnIndex, err: lastErr}
	}()

	staticEdge := len(s.config.EdgeAddrs) > 0

	// shouldRetry reports whether cause is a failure worth another attempt.
	shouldRetry := func(cause error) bool {
		// Try again for Unauthorized errors because we hope them to be
		// transient due to edge propagation lag on new Tunnels.
		if strings.Contains(cause.Error(), "Unauthorized") {
			return true
		}
		switch cause.(type) {
		case edgediscovery.ErrNoAddressesLeft:
			// If your provided addresses are not available, we will keep trying regardless.
			return staticEdge
		case connection.DupConnRegisterTunnelError,
			*quic.IdleTimeoutError,
			*quic.ApplicationError,
			edgediscovery.DialError,
			*connection.EdgeQuicDialError,
			*connection.ControlStreamError,
			*connection.StreamListenerError,
			*connection.DatagramManagerError:
			// Try again for these types of errors
			return true
		default:
			// Uncaught errors should bail startup
			return false
		}
	}

	// If the first tunnel disconnects, keep restarting it.
	for {
		lastErr = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
		if ctx.Err() != nil || lastErr == nil {
			return
		}
		// Make sure we don't continue if there is no more fallback allowed
		if _, canRetry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !canRetry {
			return
		}
		if !shouldRetry(lastErr) {
			return
		}
	}
}
// startTunnel serves the tunnel connection with the given index and, once Serve
// returns, reports the result on s.tunnelErrors. The send blocks until the result
// is received, so this is expected to run in its own goroutine.
func (s *Supervisor) startTunnel(
	ctx context.Context,
	index int,
	connectedSignal *signal.Signal,
) {
	fallback := s.tunnelsProtocolFallback[index]
	// nolint: gosec
	connIndex := uint8(index)
	s.tunnelErrors <- tunnelError{
		index: index,
		err:   s.edgeTunnelServer.Serve(ctx, connIndex, fallback, connectedSignal),
	}
}
// newConnectedTunnelSignal creates a fresh signal channel for the tunnel with the
// given index, records it in s.tunnelsConnecting, makes it the channel/index pair
// currently tracked by nextConnectedSignal/nextConnectedIndex, and returns it
// wrapped in a signal.Signal.
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
	connected := make(chan struct{})
	s.nextConnectedIndex, s.nextConnectedSignal = index, connected
	s.tunnelsConnecting[index] = connected
	return signal.New(connected)
}
// waitForNextTunnel stops tracking the tunnel with the given index as connecting
// and picks another still-connecting tunnel (an arbitrary one, since map order is
// unspecified) to track via nextConnectedIndex/nextConnectedSignal.
//
// It returns true if such a tunnel exists. Otherwise it clears
// nextConnectedSignal, leaves nextConnectedIndex untouched and returns false.
func (s *Supervisor) waitForNextTunnel(index int) bool {
	delete(s.tunnelsConnecting, index)
	for pendingIndex, pendingSignal := range s.tunnelsConnecting {
		s.nextConnectedIndex, s.nextConnectedSignal = pendingIndex, pendingSignal
		return true
	}
	s.nextConnectedSignal = nil
	return false
}