mirror of
https://github.com/fumiama/terasu-cloudflared.git
synced 2026-06-05 00:50:24 +08:00
Corrects the pattern of using errgroups and context cancellation to simplify the logic for canceling extra routines for the QUIC connection. This is because the extra context cancellation is redundant with the fact that the errgroup also cancels its own provided context when a routine returns (error or not). For the datagram handler specifically, since it can respond faster to a context cancellation from the QUIC connection, we wrap the error before surfacing it outside of the QUIC connection scope to the supervisor. Additionally, the supervisor will look for this error type to check if it should retry the QUIC connection. These two operations are required because the supervisor does not look for a context canceled error when deciding to retry a connection. If a context canceled error from the datagram handler were to be returned up to the supervisor on the initial connection, the cloudflared application would exit. We want to ensure that cloudflared maintains connection attempts even if any of the services on top of a QUIC connection fail (the datagram handler in this case). Additional logging is also introduced along these paths to help with understanding the error conditions from the specific handlers on top of a QUIC connection. Related CUSTESC-53681 Closes TUN-9610
331 lines
10 KiB
Go
331 lines
10 KiB
Go
package supervisor
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/cloudflare/cloudflared/connection"
|
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
|
"github.com/cloudflare/cloudflared/orchestration"
|
|
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
|
"github.com/cloudflare/cloudflared/retry"
|
|
"github.com/cloudflare/cloudflared/signal"
|
|
"github.com/cloudflare/cloudflared/tunnelstate"
|
|
)
|
|
|
|
const (
	// tunnelRetryDuration is the waiting time before retrying a failed tunnel connection.
	tunnelRetryDuration = 10 * time.Second
	// registrationInterval is the interval between registering new tunnels.
	registrationInterval = time.Second
)
|
|
|
|
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
|
|
// reconnects them if they disconnect.
|
|
type Supervisor struct {
|
|
config *TunnelConfig
|
|
orchestrator *orchestration.Orchestrator
|
|
edgeIPs *edgediscovery.Edge
|
|
edgeTunnelServer TunnelServer
|
|
tunnelErrors chan tunnelError
|
|
tunnelsConnecting map[int]chan struct{}
|
|
tunnelsProtocolFallback map[int]*protocolFallback
|
|
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
|
// currently-connecting tunnels to finish connecting so we can reset backoff timer
|
|
nextConnectedIndex int
|
|
nextConnectedSignal chan struct{}
|
|
|
|
log *ConnAwareLogger
|
|
logTransport *zerolog.Logger
|
|
|
|
reconnectCh chan ReconnectSignal
|
|
gracefulShutdownC <-chan struct{}
|
|
}
|
|
|
|
// errEarlyShutdown is returned by initialize when graceful shutdown is
// signalled before the first tunnel has connected; Run treats it as a clean
// exit and returns nil.
var errEarlyShutdown = errors.New("shutdown started")
|
|
|
|
// tunnelError is the result a tunnel goroutine reports on
// Supervisor.tunnelErrors when it stops serving.
type tunnelError struct {
	// index is the connection index of the tunnel that terminated.
	index int
	// err is the error the tunnel terminated with; nil means a clean exit.
	err error
}
|
|
|
|
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
|
|
isStaticEdge := len(config.EdgeAddrs) > 0
|
|
|
|
var err error
|
|
var edgeIPs *edgediscovery.Edge
|
|
if isStaticEdge { // static edge addresses
|
|
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
|
|
} else {
|
|
edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tracker := tunnelstate.NewConnTracker(config.Log)
|
|
log := NewConnAwareLogger(config.Log, tracker, config.Observer)
|
|
|
|
edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
|
|
edgeBindAddr := config.EdgeBindAddr
|
|
|
|
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
|
|
|
|
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, config.OriginDialerService, orchestrator.GetFlowLimiter())
|
|
|
|
edgeTunnelServer := EdgeTunnelServer{
|
|
config: config,
|
|
orchestrator: orchestrator,
|
|
sessionManager: sessionManager,
|
|
datagramMetrics: datagramMetrics,
|
|
edgeAddrs: edgeIPs,
|
|
edgeAddrHandler: edgeAddrHandler,
|
|
edgeBindAddr: edgeBindAddr,
|
|
tracker: tracker,
|
|
reconnectCh: reconnectCh,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
connAwareLogger: log,
|
|
}
|
|
|
|
return &Supervisor{
|
|
config: config,
|
|
orchestrator: orchestrator,
|
|
edgeIPs: edgeIPs,
|
|
edgeTunnelServer: &edgeTunnelServer,
|
|
tunnelErrors: make(chan tunnelError),
|
|
tunnelsConnecting: map[int]chan struct{}{},
|
|
tunnelsProtocolFallback: map[int]*protocolFallback{},
|
|
log: log,
|
|
logTransport: config.LogTransport,
|
|
reconnectCh: reconnectCh,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Supervisor) Run(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) error {
|
|
if s.config.ICMPRouterServer != nil {
|
|
go func() {
|
|
if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
s.log.Logger().Info().Err(err).Msg("icmp router terminated")
|
|
} else {
|
|
s.log.Logger().Err(err).Msg("icmp router terminated")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Setup DNS Resolver refresh
|
|
go s.config.OriginDNSService.StartRefreshLoop(ctx)
|
|
|
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
|
if err == errEarlyShutdown {
|
|
return nil
|
|
}
|
|
s.log.Logger().Error().Err(err).Msg("initial tunnel connection failed")
|
|
return err
|
|
}
|
|
var tunnelsWaiting []int
|
|
tunnelsActive := s.config.HAConnections
|
|
|
|
backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
|
|
var backoffTimer <-chan time.Time
|
|
|
|
shuttingDown := false
|
|
for {
|
|
select {
|
|
// Context cancelled
|
|
case <-ctx.Done():
|
|
for tunnelsActive > 0 {
|
|
<-s.tunnelErrors
|
|
tunnelsActive--
|
|
}
|
|
return nil
|
|
// startTunnel completed with a response
|
|
// (note that this may also be caused by context cancellation)
|
|
case tunnelError := <-s.tunnelErrors:
|
|
tunnelsActive--
|
|
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
|
if tunnelError.err != nil && !shuttingDown {
|
|
switch tunnelError.err.(type) {
|
|
case ReconnectSignal:
|
|
// For tunnels that closed with reconnect signal, we reconnect immediately
|
|
go s.startTunnel(ctx, tunnelError.index, s.newConnectedTunnelSignal(tunnelError.index))
|
|
tunnelsActive++
|
|
continue
|
|
}
|
|
// Make sure we don't continue if there is no more fallback allowed
|
|
if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
|
|
continue
|
|
}
|
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
|
s.waitForNextTunnel(tunnelError.index)
|
|
|
|
if backoffTimer == nil {
|
|
backoffTimer = backoff.BackoffTimer()
|
|
}
|
|
} else if tunnelsActive == 0 {
|
|
s.log.ConnAwareLogger().Msg("no more connections active and exiting")
|
|
// All connected tunnels exited gracefully, no more work to do
|
|
return nil
|
|
}
|
|
// Backoff was set and its timer expired
|
|
case <-backoffTimer:
|
|
backoffTimer = nil
|
|
for _, index := range tunnelsWaiting {
|
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
|
}
|
|
tunnelsActive += len(tunnelsWaiting)
|
|
tunnelsWaiting = nil
|
|
// Tunnel successfully connected
|
|
case <-s.nextConnectedSignal:
|
|
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
|
// No more tunnels outstanding, clear backoff timer
|
|
backoff.SetGracePeriod()
|
|
}
|
|
case <-s.gracefulShutdownC:
|
|
shuttingDown = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns nil if initialization succeeded, else the initialization error.
|
|
// Attempts here will be made to connect one tunnel, if successful, it will
|
|
// connect the available tunnels up to config.HAConnections.
|
|
func (s *Supervisor) initialize(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) error {
|
|
availableAddrs := s.edgeIPs.AvailableAddrs()
|
|
if s.config.HAConnections > availableAddrs {
|
|
s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
|
s.config.HAConnections = availableAddrs
|
|
}
|
|
s.tunnelsProtocolFallback[0] = &protocolFallback{
|
|
retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
|
|
s.config.ProtocolSelector.Current(),
|
|
false,
|
|
}
|
|
|
|
go s.startFirstTunnel(ctx, connectedSignal)
|
|
|
|
// Wait for response from first tunnel before proceeding to attempt other HA edge tunnels
|
|
select {
|
|
case <-ctx.Done():
|
|
<-s.tunnelErrors
|
|
return ctx.Err()
|
|
case tunnelError := <-s.tunnelErrors:
|
|
return tunnelError.err
|
|
case <-s.gracefulShutdownC:
|
|
return errEarlyShutdown
|
|
case <-connectedSignal.Wait():
|
|
}
|
|
|
|
// At least one successful connection, so start the rest
|
|
for i := 1; i < s.config.HAConnections; i++ {
|
|
s.tunnelsProtocolFallback[i] = &protocolFallback{
|
|
retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
|
|
// Set the protocol we know the first tunnel connected with.
|
|
s.tunnelsProtocolFallback[0].protocol,
|
|
false,
|
|
}
|
|
go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
|
|
time.Sleep(registrationInterval)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
|
func (s *Supervisor) startFirstTunnel(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) {
|
|
var err error
|
|
const firstConnIndex = 0
|
|
isStaticEdge := len(s.config.EdgeAddrs) > 0
|
|
defer func() {
|
|
s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
|
|
}()
|
|
|
|
// If the first tunnel disconnects, keep restarting it.
|
|
for {
|
|
err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
if err == nil {
|
|
return
|
|
}
|
|
// Make sure we don't continue if there is no more fallback allowed
|
|
if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry {
|
|
return
|
|
}
|
|
// Try again for Unauthorized errors because we hope them to be
|
|
// transient due to edge propagation lag on new Tunnels.
|
|
if strings.Contains(err.Error(), "Unauthorized") {
|
|
continue
|
|
}
|
|
switch err.(type) {
|
|
case edgediscovery.ErrNoAddressesLeft:
|
|
// If your provided addresses are not available, we will keep trying regardless.
|
|
if !isStaticEdge {
|
|
return
|
|
}
|
|
case connection.DupConnRegisterTunnelError,
|
|
*quic.IdleTimeoutError,
|
|
*quic.ApplicationError,
|
|
edgediscovery.DialError,
|
|
*connection.EdgeQuicDialError,
|
|
*connection.ControlStreamError,
|
|
*connection.StreamListenerError,
|
|
*connection.DatagramManagerError:
|
|
// Try again for these types of errors
|
|
default:
|
|
// Uncaught errors should bail startup
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
|
// s.tunnelError as this is expected to run in a goroutine.
|
|
func (s *Supervisor) startTunnel(
|
|
ctx context.Context,
|
|
index int,
|
|
connectedSignal *signal.Signal,
|
|
) {
|
|
// nolint: gosec
|
|
err := s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
|
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
|
}
|
|
|
|
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
|
sig := make(chan struct{})
|
|
s.tunnelsConnecting[index] = sig
|
|
s.nextConnectedSignal = sig
|
|
s.nextConnectedIndex = index
|
|
return signal.New(sig)
|
|
}
|
|
|
|
func (s *Supervisor) waitForNextTunnel(index int) bool {
|
|
delete(s.tunnelsConnecting, index)
|
|
s.nextConnectedSignal = nil
|
|
for k, v := range s.tunnelsConnecting {
|
|
s.nextConnectedIndex = k
|
|
s.nextConnectedSignal = v
|
|
return true
|
|
}
|
|
return false
|
|
}
|