From b89c092c1b1cf843e59e6bb43f684efe3d263403 Mon Sep 17 00:00:00 2001 From: Devin Carr Date: Wed, 12 Apr 2023 09:43:38 -0700 Subject: [PATCH] TUN-7134: Acquire token for cloudflared tail cloudflared tail will now fetch the management token from by making a request to the Cloudflare API using the cert.pem (acquired from cloudflared login). Refactored some of the credentials code into it's own package as to allow for easier use between subcommands outside of `cloudflared tunnel`. --- certutil/certutil.go | 58 -------- certutil/certutil_test.go | 51 ------- cfapi/client.go | 1 + cfapi/tunnel.go | 26 ++++ cmd/cloudflared/cliutil/build_info.go | 4 + cmd/cloudflared/main.go | 2 +- cmd/cloudflared/tail/cmd.go | 91 ++++++++++-- cmd/cloudflared/tunnel/cmd.go | 5 +- cmd/cloudflared/tunnel/configuration.go | 72 ---------- cmd/cloudflared/tunnel/credential_finder.go | 7 +- cmd/cloudflared/tunnel/login.go | 3 +- cmd/cloudflared/tunnel/subcommand_context.go | 59 ++------ .../tunnel/subcommand_context_test.go | 5 +- config/configuration.go | 2 - credentials/credentials.go | 83 +++++++++++ credentials/credentials_test.go | 38 +++++ credentials/origin_cert.go | 130 ++++++++++++++++++ credentials/origin_cert_test.go | 110 +++++++++++++++ .../test-cert-no-token.pem | 0 .../test-cert-unknown-block.pem | 0 .../test-cloudflare-tunnel-cert-json.pem | 0 21 files changed, 497 insertions(+), 250 deletions(-) delete mode 100644 certutil/certutil.go delete mode 100644 certutil/certutil_test.go create mode 100644 credentials/credentials.go create mode 100644 credentials/credentials_test.go create mode 100644 credentials/origin_cert.go create mode 100644 credentials/origin_cert_test.go rename {certutil => credentials}/test-cert-no-token.pem (100%) rename {certutil => credentials}/test-cert-unknown-block.pem (100%) rename {certutil => credentials}/test-cloudflare-tunnel-cert-json.pem (100%) diff --git a/certutil/certutil.go b/certutil/certutil.go deleted file mode 100644 index 951926bb..00000000 --- a/certutil/certutil.go +++ /dev/null @@ -1,58 +0,0 @@ -package certutil - -import ( - "encoding/json" - "encoding/pem" - "fmt" -) - -type namedTunnelToken struct { - ZoneID string `json:"zoneID"` - AccountID string `json:"accountID"` - APIToken string `json:"apiToken"` -} - -type OriginCert struct { - ZoneID string - APIToken string - AccountID string -} - -func DecodeOriginCert(blocks []byte) (*OriginCert, error) { - if len(blocks) == 0 { - return nil, fmt.Errorf("Cannot decode empty certificate") - } - originCert := OriginCert{} - block, rest := pem.Decode(blocks) - for { - if block == nil { - break - } - switch block.Type { - case "PRIVATE KEY", "CERTIFICATE": - // this is for legacy purposes. - break - case "ARGO TUNNEL TOKEN": - if originCert.ZoneID != "" || originCert.APIToken != "" { - return nil, fmt.Errorf("Found multiple tokens in the certificate") - } - // The token is a string, - // Try the newer JSON format - ntt := namedTunnelToken{} - if err := json.Unmarshal(block.Bytes, &ntt); err == nil { - originCert.ZoneID = ntt.ZoneID - originCert.APIToken = ntt.APIToken - originCert.AccountID = ntt.AccountID - } - default: - return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type) - } - block, rest = pem.Decode(rest) - } - - if originCert.ZoneID == "" || originCert.APIToken == "" { - return nil, fmt.Errorf("Missing token in the certificate") - } - - return &originCert, nil -} diff --git a/certutil/certutil_test.go b/certutil/certutil_test.go deleted file mode 100644 index e48ffcf3..00000000 --- a/certutil/certutil_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package certutil - -import ( - "fmt" - "io/ioutil" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLoadOriginCert(t *testing.T) { - cert, err := DecodeOriginCert([]byte{}) - assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err) - assert.Nil(t, cert) - - blocks, err := ioutil.ReadFile("test-cert-unknown-block.pem") - assert.Nil(t, err) - cert, err = DecodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err) - assert.Nil(t, cert) -} - -func TestJSONArgoTunnelTokenEmpty(t *testing.T) { - cert, err := DecodeOriginCert([]byte{}) - blocks, err := ioutil.ReadFile("test-cert-no-token.pem") - assert.Nil(t, err) - cert, err = DecodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err) - assert.Nil(t, cert) -} - -func TestJSONArgoTunnelToken(t *testing.T) { - // The given cert's Argo Tunnel Token was generated by base64 encoding this JSON: - // { - // "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19", - // "apiToken": "test-service-key", - // "accountID": "abcdabcdabcdabcd1234567890abcdef" - // } - CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem") -} - -func CloudflareTunnelTokenTest(t *testing.T, path string) { - blocks, err := ioutil.ReadFile(path) - assert.Nil(t, err) - cert, err := DecodeOriginCert(blocks) - assert.Nil(t, err) - assert.NotNil(t, cert) - assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) - key := "test-service-key" - assert.Equal(t, key, cert.APIToken) -} diff --git a/cfapi/client.go b/cfapi/client.go index 192d64ce..f8c2a734 100644 --- a/cfapi/client.go +++ b/cfapi/client.go @@ -8,6 +8,7 @@ type TunnelClient interface { CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) GetTunnelToken(tunnelID uuid.UUID) (string, error) + GetManagementToken(tunnelID uuid.UUID) (string, error) DeleteTunnel(tunnelID uuid.UUID) error ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) diff --git a/cfapi/tunnel.go b/cfapi/tunnel.go index 87caf230..fa6f8f33 100644 --- a/cfapi/tunnel.go +++ b/cfapi/tunnel.go @@ -50,6 +50,10 @@ type newTunnel struct { TunnelSecret []byte `json:"tunnel_secret"` } +type managementRequest struct { + Resources []string `json:"resources"` +} + type CleanupParams struct { queryParams url.Values } @@ -133,6 +137,28 @@ func (r *RESTClient) GetTunnelToken(tunnelID uuid.UUID) (token string, err error return "", r.statusCodeToError("get tunnel token", resp) } +func (r *RESTClient) GetManagementToken(tunnelID uuid.UUID) (token string, err error) { + endpoint := r.baseEndpoints.accountLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/management", tunnelID)) + + body := &managementRequest{ + Resources: []string{"logs"}, + } + + resp, err := r.sendRequest("POST", endpoint, body) + if err != nil { + return "", errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + err = parseResponse(resp.Body, &token) + return token, err + } + + return "", r.statusCodeToError("get tunnel token", resp) +} + func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { endpoint := r.baseEndpoints.accountLevel endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID)) diff --git a/cmd/cloudflared/cliutil/build_info.go b/cmd/cloudflared/cliutil/build_info.go index 4d73701e..fff4febf 100644 --- a/cmd/cloudflared/cliutil/build_info.go +++ b/cmd/cloudflared/cliutil/build_info.go @@ -47,3 +47,7 @@ func (bi *BuildInfo) GetBuildTypeMsg() string { } return fmt.Sprintf(" with %s", bi.BuildType) } + +func (bi *BuildInfo) UserAgent() string { + return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion) +} diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index f729d55a..3ed25a62 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -90,7 +90,7 @@ func main() { updater.Init(Version) tracing.Init(Version) token.Init(Version) - tail.Init(Version) + tail.Init(bInfo) runApp(app, graceShutdownC) } diff --git a/cmd/cloudflared/tail/cmd.go b/cmd/cloudflared/tail/cmd.go index 55864e90..24d7fb65 100644 --- a/cmd/cloudflared/tail/cmd.go +++ b/cmd/cloudflared/tail/cmd.go @@ -2,6 +2,7 @@ package tail import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -10,28 +11,32 @@ import ( "syscall" "time" + "github.com/google/uuid" "github.com/mattn/go-colorable" "github.com/rs/zerolog" "github.com/urfave/cli/v2" "nhooyr.io/websocket" + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/management" ) var ( - version string + buildInfo *cliutil.BuildInfo ) -func Init(v string) { - version = v +func Init(bi *cliutil.BuildInfo) { + buildInfo = bi } func Command() *cli.Command { return &cli.Command{ - Name: "tail", - Action: Run, - Usage: "Stream logs from a remote cloudflared", + Name: "tail", + Action: Run, + Usage: "Stream logs from a remote cloudflared", + UsageText: "cloudflared tail [tail command options] [TUNNEL-ID]", Flags: []cli.Flag{ &cli.StringFlag{ Name: "connector-id", @@ -75,6 +80,12 @@ func Command() *cli.Command { Usage: "Application logging level {debug, info, warn, error, fatal}", EnvVars: []string{"TUNNEL_LOGLEVEL"}, }, + &cli.StringFlag{ + Name: credentials.OriginCertFlag, + Usage: "Path to the certificate generated for your origin when you run cloudflared login.", + EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, + Value: credentials.FindDefaultOriginCertPath(), + }, }, } } @@ -159,6 +170,59 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) { }, nil } +// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel. +func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { + userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log) + if err != nil { + return "", err + } + + client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log) + if err != nil { + return "", err + } + + tunnelIDString := c.Args().First() + if tunnelIDString == "" { + return "", errors.New("no tunnel ID provided") + } + tunnelID, err := uuid.Parse(tunnelIDString) + if err != nil { + return "", errors.New("unable to parse provided tunnel id as a valid UUID") + } + + token, err := client.GetManagementToken(tunnelID) + if err != nil { + return "", err + } + + return token, nil +} + +// buildURL will build the management url to contain the required query parameters to authenticate the request. +func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) { + var err error + managementHostname := c.String("management-hostname") + token := c.String("token") + if token == "" { + token, err = getManagementToken(c, log) + if err != nil { + return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err) + } + } + query := url.Values{} + query.Add("access_token", token) + connector := c.String("connector-id") + if connector != "" { + connectorID, err := uuid.Parse(connector) + if err != nil { + return url.URL{}, fmt.Errorf("unabled to parse 'connector-id' flag into a valid UUID: %w", err) + } + query.Add("connector_id", connectorID.String()) + } + return url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: query.Encode()}, nil +} + // Run implements a foreground runner func Run(c *cli.Context) error { log := createLogger(c) @@ -173,12 +237,14 @@ func Run(c *cli.Context) error { return nil } - managementHostname := c.String("management-hostname") - token := c.String("token") - u := url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: "access_token=" + token} + u, err := buildURL(c, log) + if err != nil { + log.Err(err).Msg("unable to construct management request URL") + return nil + } header := make(http.Header) - header.Add("User-Agent", "cloudflared/"+version) + header.Add("User-Agent", buildInfo.UserAgent()) trace := c.String("trace") if trace != "" { header["cf-trace-id"] = []string{trace} @@ -206,6 +272,11 @@ func Run(c *cli.Context) error { log.Error().Err(err).Msg("unable to request logs from management tunnel") return nil } + log.Debug(). + Str("tunnel-id", c.Args().First()). + Str("connector-id", c.String("connector-id")). + Interface("filters", filters). + Msg("connected") readerDone := make(chan struct{}) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 17c177f2..0eb5b328 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -28,6 +28,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" @@ -751,10 +752,10 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, }, altsrc.NewStringFlag(&cli.StringFlag{ - Name: "origincert", + Name: credentials.OriginCertFlag, Usage: "Path to the certificate generated for your origin when you run cloudflared login.", EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, - Value: findDefaultOriginCertPath(), + Value: credentials.FindDefaultOriginCertPath(), Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index adcc82ef..89b76392 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -3,17 +3,14 @@ package tunnel import ( "crypto/tls" "fmt" - "io/ioutil" mathRand "math/rand" "net" "net/netip" "os" - "path/filepath" "strings" "time" "github.com/google/uuid" - homedir "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" @@ -33,7 +30,6 @@ import ( tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -const LogFieldOriginCertPath = "originCertPath" const secretValue = "*****" var ( @@ -46,18 +42,6 @@ var ( configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"} ) -// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories -// contains a cert.pem file, return empty string -func findDefaultOriginCertPath() string { - for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() { - originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, config.DefaultCredentialFile)) - if ok, _ := config.FileExists(originCertPath); ok { - return originCertPath - } - } - return "" -} - func generateRandomClientID(log *zerolog.Logger) (string, error) { u, err := uuid.NewRandom() if err != nil { @@ -128,62 +112,6 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope namedTunnel != nil) // named tunnel } -func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) { - if originCertPath == "" { - log.Info().Msgf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigSearchDirectories()) - if isRunningFromTerminal() { - log.Error().Msgf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl) - return "", fmt.Errorf("client didn't specify origincert path when running from terminal") - } else { - log.Error().Msgf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl) - return "", fmt.Errorf("client didn't specify origincert path") - } - } - var err error - originCertPath, err = homedir.Expand(originCertPath) - if err != nil { - log.Err(err).Msgf("Cannot resolve origin certificate path") - return "", fmt.Errorf("cannot resolve path %s", originCertPath) - } - // Check that the user has acquired a certificate using the login command - ok, err := config.FileExists(originCertPath) - if err != nil { - log.Error().Err(err).Msgf("Cannot check if origin cert exists at path %s", originCertPath) - return "", fmt.Errorf("cannot check if origin cert exists at path %s", originCertPath) - } - if !ok { - log.Error().Msgf(`Cannot find a valid certificate for your origin at the path: - - %s - -If the path above is wrong, specify the path with the -origincert option. -If you don't have a certificate signed by Cloudflare, run the command: - - %s login -`, originCertPath, os.Args[0]) - return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath) - } - - return originCertPath, nil -} - -func readOriginCert(originCertPath string) ([]byte, error) { - // Easier to send the certificate as []byte via RPC than decoding it at this point - originCert, err := ioutil.ReadFile(originCertPath) - if err != nil { - return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath) - } - return originCert, nil -} - -func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) { - if originCertPath, err := findOriginCert(originCertPath, log); err != nil { - return nil, err - } else { - return readOriginCert(originCertPath) - } -} - func prepareTunnelConfig( c *cli.Context, info *cliutil.BuildInfo, diff --git a/cmd/cloudflared/tunnel/credential_finder.go b/cmd/cloudflared/tunnel/credential_finder.go index a2320af4..92e05495 100644 --- a/cmd/cloudflared/tunnel/credential_finder.go +++ b/cmd/cloudflared/tunnel/credential_finder.go @@ -5,6 +5,7 @@ import ( "path/filepath" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/credentials" "github.com/google/uuid" "github.com/rs/zerolog" @@ -56,13 +57,13 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys } func (s searchByID) Path() (string, error) { - originCertPath := s.c.String("origincert") + originCertPath := s.c.String(credentials.OriginCertFlag) originCertLog := s.log.With(). - Str(LogFieldOriginCertPath, originCertPath). + Str("originCertPath", originCertPath). Logger() // Fallback to look for tunnel credentials in the origin cert directory - if originCertPath, err := findOriginCert(originCertPath, &originCertLog); err == nil { + if originCertPath, err := credentials.FindOriginCert(originCertPath, &originCertLog); err == nil { originCertDir := filepath.Dir(originCertPath) if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil { if s.fs.validFilePath(filePath) { diff --git a/cmd/cloudflared/tunnel/login.go b/cmd/cloudflared/tunnel/login.go index 8b519147..dd0f8fe9 100644 --- a/cmd/cloudflared/tunnel/login.go +++ b/cmd/cloudflared/tunnel/login.go @@ -14,6 +14,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/token" ) @@ -85,7 +86,7 @@ func checkForExistingCert() (string, bool, error) { if err != nil { return "", false, err } - path := filepath.Join(configPath, config.DefaultCredentialFile) + path := filepath.Join(configPath, credentials.DefaultCredentialFile) fileInfo, err := os.Stat(path) if err == nil && fileInfo.Size() > 0 { return path, true, nil diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index bc65aced..f49c15eb 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -13,9 +13,9 @@ import ( "github.com/rs/zerolog" "github.com/urfave/cli/v2" - "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" ) @@ -37,7 +37,7 @@ type subcommandContext struct { // These fields should be accessed using their respective Getter tunnelstoreClient cfapi.Client - userCredential *userCredential + userCredential *credentials.User } func newSubcommandContext(c *cli.Context) (*subcommandContext, error) { @@ -56,65 +56,28 @@ func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder { return newSearchByID(tunnelID, sc.c, sc.log, sc.fs) } -type userCredential struct { - cert *certutil.OriginCert - certPath string -} - func (sc *subcommandContext) client() (cfapi.Client, error) { if sc.tunnelstoreClient != nil { return sc.tunnelstoreClient, nil } - credential, err := sc.credential() + cred, err := sc.credential() if err != nil { return nil, err } - userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version()) - client, err := cfapi.NewRESTClient( - sc.c.String("api-url"), - credential.cert.AccountID, - credential.cert.ZoneID, - credential.cert.APIToken, - userAgent, - sc.log, - ) - + sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log) if err != nil { return nil, err } - sc.tunnelstoreClient = client - return client, nil + return sc.tunnelstoreClient, nil } -func (sc *subcommandContext) credential() (*userCredential, error) { +func (sc *subcommandContext) credential() (*credentials.User, error) { if sc.userCredential == nil { - originCertPath := sc.c.String("origincert") - originCertLog := sc.log.With(). - Str(LogFieldOriginCertPath, originCertPath). - Logger() - - originCertPath, err := findOriginCert(originCertPath, &originCertLog) + uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log) if err != nil { - return nil, errors.Wrap(err, "Error locating origin cert") - } - blocks, err := readOriginCert(originCertPath) - if err != nil { - return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) - } - - cert, err := certutil.DecodeOriginCert(blocks) - if err != nil { - return nil, errors.Wrap(err, "Error decoding origin cert") - } - - if cert.AccountID == "" { - return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) - } - - sc.userCredential = &userCredential{ - cert: cert, - certPath: originCertPath, + return nil, err } + sc.userCredential = uc } return sc.userCredential, nil } @@ -175,13 +138,13 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec return nil, err } tunnelCredentials := connection.Credentials{ - AccountTag: credential.cert.AccountID, + AccountTag: credential.AccountID(), TunnelSecret: tunnelSecret, TunnelID: tunnel.ID, } usedCertPath := false if credentialsFilePath == "" { - originCertDir := filepath.Dir(credential.certPath) + originCertDir := filepath.Dir(credential.CertPath()) credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir) if err != nil { return nil, err diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index 35cc46e7..c2293463 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -16,6 +16,7 @@ import ( "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" ) type mockFileSystem struct { @@ -37,7 +38,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) { log *zerolog.Logger fs fileSystem tunnelstoreClient cfapi.Client - userCredential *userCredential + userCredential *credentials.User } type args struct { tunnelID uuid.UUID @@ -249,7 +250,7 @@ func Test_subcommandContext_Delete(t *testing.T) { isUIEnabled bool fs fileSystem tunnelstoreClient *deleteMockTunnelStore - userCredential *userCredential + userCredential *credentials.User } type args struct { tunnelIDs []uuid.UUID diff --git a/config/configuration.go b/config/configuration.go index 70bf163a..73d45fbc 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -39,8 +39,6 @@ var ( ) const ( - DefaultCredentialFile = "cert.pem" - // BastionFlag is to enable bastion, or jump host, operation BastionFlag = "bastion" ) diff --git a/credentials/credentials.go b/credentials/credentials.go new file mode 100644 index 00000000..8d1d8908 --- /dev/null +++ b/credentials/credentials.go @@ -0,0 +1,83 @@ +package credentials + +import ( + "github.com/pkg/errors" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/cfapi" +) + +const ( + logFieldOriginCertPath = "originCertPath" +) + +type User struct { + cert *OriginCert + certPath string +} + +func (c User) AccountID() string { + return c.cert.AccountID +} + +func (c User) ZoneID() string { + return c.cert.ZoneID +} + +func (c User) APIToken() string { + return c.cert.APIToken +} + +func (c User) CertPath() string { + return c.certPath +} + +// Client uses the user credentials to create a Cloudflare API client +func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) { + if apiURL == "" { + return nil, errors.New("An api-url was not provided for the Cloudflare API client") + } + client, err := cfapi.NewRESTClient( + apiURL, + c.cert.AccountID, + c.cert.ZoneID, + c.cert.APIToken, + userAgent, + log, + ) + + if err != nil { + return nil, err + } + return client, nil +} + +// Read will load and read the origin cert.pem to load the user credentials +func Read(originCertPath string, log *zerolog.Logger) (*User, error) { + originCertLog := log.With(). + Str(logFieldOriginCertPath, originCertPath). + Logger() + + originCertPath, err := FindOriginCert(originCertPath, &originCertLog) + if err != nil { + return nil, errors.Wrap(err, "Error locating origin cert") + } + blocks, err := readOriginCert(originCertPath) + if err != nil { + return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) + } + + cert, err := decodeOriginCert(blocks) + if err != nil { + return nil, errors.Wrap(err, "Error decoding origin cert") + } + + if cert.AccountID == "" { + return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) + } + + return &User{ + cert: cert, + certPath: originCertPath, + }, nil +} diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go new file mode 100644 index 00000000..d9b2d7b7 --- /dev/null +++ b/credentials/credentials_test.go @@ -0,0 +1,38 @@ +package credentials + +import ( + "io/fs" + "os" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCredentialsRead(t *testing.T) { + file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") + require.NoError(t, err) + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + os.WriteFile(certPath, file, fs.ModePerm) + user, err := Read(certPath, &nopLog) + require.NoError(t, err) + require.Equal(t, certPath, user.CertPath()) + require.Equal(t, "test-service-key", user.APIToken()) + require.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", user.ZoneID()) + require.Equal(t, "abcdabcdabcdabcd1234567890abcdef", user.AccountID()) +} + +func TestCredentialsClient(t *testing.T) { + user := User{ + certPath: "/tmp/cert.pem", + cert: &OriginCert{ + ZoneID: "7b0a4d77dfb881c1a3b7d61ea9443e19", + AccountID: "abcdabcdabcdabcd1234567890abcdef", + APIToken: "test-service-key", + }, + } + client, err := user.Client("example.com", "cloudflared/test", &nopLog) + require.NoError(t, err) + require.NotNil(t, client) +} diff --git a/credentials/origin_cert.go b/credentials/origin_cert.go new file mode 100644 index 00000000..73a59fa3 --- /dev/null +++ b/credentials/origin_cert.go @@ -0,0 +1,130 @@ +package credentials + +import ( + "encoding/json" + "encoding/pem" + "fmt" + "os" + "path/filepath" + + "github.com/mitchellh/go-homedir" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/config" +) + +const ( + DefaultCredentialFile = "cert.pem" + OriginCertFlag = "origincert" +) + +type namedTunnelToken struct { + ZoneID string `json:"zoneID"` + AccountID string `json:"accountID"` + APIToken string `json:"apiToken"` +} + +type OriginCert struct { + ZoneID string + APIToken string + AccountID string +} + +// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the +// DefaultConfigSearchDirectories contains a cert.pem file, return empty string +func FindDefaultOriginCertPath() string { + for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() { + originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile)) + if ok := fileExists(originCertPath); ok { + return originCertPath + } + } + return "" +} + +func decodeOriginCert(blocks []byte) (*OriginCert, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("Cannot decode empty certificate") + } + originCert := OriginCert{} + block, rest := pem.Decode(blocks) + for { + if block == nil { + break + } + switch block.Type { + case "PRIVATE KEY", "CERTIFICATE": + // this is for legacy purposes. + break + case "ARGO TUNNEL TOKEN": + if originCert.ZoneID != "" || originCert.APIToken != "" { + return nil, fmt.Errorf("Found multiple tokens in the certificate") + } + // The token is a string, + // Try the newer JSON format + ntt := namedTunnelToken{} + if err := json.Unmarshal(block.Bytes, &ntt); err == nil { + originCert.ZoneID = ntt.ZoneID + originCert.APIToken = ntt.APIToken + originCert.AccountID = ntt.AccountID + } + default: + return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type) + } + block, rest = pem.Decode(rest) + } + + if originCert.ZoneID == "" || originCert.APIToken == "" { + return nil, fmt.Errorf("Missing token in the certificate") + } + + return &originCert, nil +} + +func readOriginCert(originCertPath string) ([]byte, error) { + originCert, err := os.ReadFile(originCertPath) + if err != nil { + return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath) + } + + return originCert, nil +} + +// FindOriginCert will check to make sure that the certificate exists at the specified file path. +func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) { + if originCertPath == "" { + log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories()) + return "", fmt.Errorf("client didn't specify origincert path") + } + var err error + originCertPath, err = homedir.Expand(originCertPath) + if err != nil { + log.Err(err).Msgf("Cannot resolve origin certificate path") + return "", fmt.Errorf("cannot resolve path %s", originCertPath) + } + // Check that the user has acquired a certificate using the login command + ok := fileExists(originCertPath) + if !ok { + log.Error().Msgf(`Cannot find a valid certificate for your origin at the path: + + %s + +If the path above is wrong, specify the path with the -origincert option. +If you don't have a certificate signed by Cloudflare, run the command: + + cloudflared login +`, originCertPath) + return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath) + } + + return originCertPath, nil +} + +// FileExists checks to see if a file exist at the provided path. +func fileExists(path string) bool { + fileStat, err := os.Stat(path) + if err != nil { + return false + } + return !fileStat.IsDir() +} diff --git a/credentials/origin_cert_test.go b/credentials/origin_cert_test.go new file mode 100644 index 00000000..77a473e4 --- /dev/null +++ b/credentials/origin_cert_test.go @@ -0,0 +1,110 @@ +package credentials + +import ( + "fmt" + "io/fs" + "os" + "path" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + originCertFile = "cert.pem" +) + +var ( + nopLog = zerolog.Nop().With().Logger() +) + +func TestLoadOriginCert(t *testing.T) { + cert, err := decodeOriginCert([]byte{}) + assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err) + assert.Nil(t, cert) + + blocks, err := os.ReadFile("test-cert-unknown-block.pem") + assert.NoError(t, err) + cert, err = decodeOriginCert(blocks) + assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err) + assert.Nil(t, cert) +} + +func TestJSONArgoTunnelTokenEmpty(t *testing.T) { + blocks, err := os.ReadFile("test-cert-no-token.pem") + assert.NoError(t, err) + cert, err := decodeOriginCert(blocks) + assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err) + assert.Nil(t, cert) +} + +func TestJSONArgoTunnelToken(t *testing.T) { + // The given cert's Argo Tunnel Token was generated by base64 encoding this JSON: + // { + // "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19", + // "apiToken": "test-service-key", + // "accountID": "abcdabcdabcdabcd1234567890abcdef" + // } + CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem") +} + +func CloudflareTunnelTokenTest(t *testing.T, path string) { + blocks, err := os.ReadFile(path) + assert.NoError(t, err) + cert, err := decodeOriginCert(blocks) + assert.NoError(t, err) + assert.NotNil(t, cert) + assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) + key := "test-service-key" + assert.Equal(t, key, cert.APIToken) +} + +type mockFile struct { + path string + data []byte + err error +} + +type mockFileSystem struct { + files map[string]mockFile +} + +func newMockFileSystem(files ...mockFile) *mockFileSystem { + fs := mockFileSystem{map[string]mockFile{}} + for _, f := range files { + fs.files[f.path] = f + } + return &fs +} + +func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) { + if f, ok := fs.files[path]; ok { + return f.data, f.err + } + return nil, os.ErrNotExist +} + +func (fs *mockFileSystem) ValidFilePath(path string) bool { + _, exists := fs.files[path] + return exists +} + +func TestFindOriginCert_Valid(t *testing.T) { + file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") + require.NoError(t, err) + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + os.WriteFile(certPath, file, fs.ModePerm) + path, err := FindOriginCert(certPath, &nopLog) + require.NoError(t, err) + require.Equal(t, certPath, path) +} + +func TestFindOriginCert_Missing(t *testing.T) { + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + _, err := FindOriginCert(certPath, &nopLog) + require.Error(t, err) +} diff --git a/certutil/test-cert-no-token.pem b/credentials/test-cert-no-token.pem similarity index 100% rename from certutil/test-cert-no-token.pem rename to credentials/test-cert-no-token.pem diff --git a/certutil/test-cert-unknown-block.pem b/credentials/test-cert-unknown-block.pem similarity index 100% rename from certutil/test-cert-unknown-block.pem rename to credentials/test-cert-unknown-block.pem diff --git a/certutil/test-cloudflare-tunnel-cert-json.pem b/credentials/test-cloudflare-tunnel-cert-json.pem similarity index 100% rename from certutil/test-cloudflare-tunnel-cert-json.pem rename to credentials/test-cloudflare-tunnel-cert-json.pem