Skip to content
Open
220 changes: 220 additions & 0 deletions credentials/google/gcp_service_account_identity_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
*
* Copyright 2026 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package google

import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials/idtoken"
"cloud.google.com/go/compute/metadata"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/status"
)

// earlyExpiry is how long before a token's expiry the token is treated as
// stale (see isTokenStale). The value matches the hardcoded 5-minute early
// expiry used by the cloud.google.com/go/auth/credentials/idtoken package.
var earlyExpiry = 5 * time.Minute
Comment thread
Pranjali-2501 marked this conversation as resolved.

// gcpServiceAccountIdentityCallCreds is a PerRPCCredentials implementation
// that attaches a cached ID token as a bearer "authorization" header and
// refreshes it through creds, with backoff between failed fetch attempts.
type gcpServiceAccountIdentityCallCreds struct {
// audience is the audience the ID token is requested for.
audience string
// creds supplies the TokenProvider used to fetch ID tokens.
creds *auth.Credentials
// backoff computes the delay before the next attempt after a failed fetch.
backoff backoff.Strategy

// mu is held while reading or writing the fields below, both in
// GetRequestMetadata and in the fetch goroutine started by startFetch.
mu sync.Mutex
// token is the most recently fetched token; nil until a fetch succeeds.
token *auth.Token

fetching chan struct{} // non-nil while a fetch is in flight; closed when it completes (deduplicates concurrent fetches)
nextRetryTime time.Time // earliest time a new fetch may be attempted after a failure (backoff); zero after a success
retryAttempt int // number of consecutive failures, used to calculate backoff
lastErr error // error from the last fetch attempt; nil after a success
}

// NewGcpServiceAccountIdentity creates a PerRPCCredentials that authenticates
// using a GCP Service Account Identity JWT token for the given audience.
//
// It uses the cloud.google.com/go/auth/credentials/idtoken package to
// automatically fetch ID token from the GCE metadata server. This credential
// is only valid to use in an environment running on GCP.
//
// It returns an error if audience is empty or if the underlying idtoken
// credentials cannot be created. In the latter case the original error is
// wrapped, so callers can inspect it with errors.Is / errors.As.
func NewGcpServiceAccountIdentity(audience string) (credentials.PerRPCCredentials, error) {
	if audience == "" {
		return nil, errors.New("audience cannot be empty")
	}

	creds, err := idtoken.NewCredentials(&idtoken.Options{
		Audience: audience,
	})
	if err != nil {
		// %w keeps the cause in the error chain; the message text is unchanged.
		return nil, fmt.Errorf("failed to create auth.Credentials for idtoken: %w", err)
	}

	return &gcpServiceAccountIdentityCallCreds{
		audience: audience,
		creds:    creds,
		backoff:  backoff.DefaultExponential,
	}, nil
}

// GetRequestMetadata gets the current request metadata, refreshing tokens if
// required. This implementation follows the PerRPCCredentials interface.
//
// The connection must provide at least credentials.PrivacyAndIntegrity;
// otherwise an error is returned before the token cache is consulted.
//
// At most one underlying token fetch is in flight at a time:
//   - If a valid token is cached, it is returned immediately. If that token
//     is also stale (inside the early-expiry window), a background refresh is
//     started unless one is already running or the backoff interval from a
//     previous failed attempt has not yet elapsed.
//   - If no valid token is cached and a fetch recently failed, the cached
//     error is returned until the backoff interval expires.
//   - Otherwise a new fetch is started (or an in-progress one is joined) and
//     the call blocks until the fetch completes or ctx is done, in which case
//     ctx.Err() is returned.
func (c *gcpServiceAccountIdentityCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
	ri, _ := credentials.RequestInfoFromContext(ctx)
	if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
		return nil, fmt.Errorf("cannot send secure credentials on an insecure connection: %v", err)
	}

	c.mu.Lock()

	// Fast path: serve the cached token. If it is stale, kick off a background
	// refresh, but only when no fetch is running and we are not inside the
	// backoff window of a previous failure. Without the backoff check every
	// RPC issued after a failed refresh would immediately start another fetch.
	// nextRetryTime is the zero time after a success, so the check passes then.
	if c.isTokenValid() {
		if c.isTokenStale() && c.fetching == nil && !time.Now().Before(c.nextRetryTime) {
			c.startFetch()
		}
		md := map[string]string{
			"authorization": "Bearer " + c.token.Value,
		}
		c.mu.Unlock()
		return md, nil
	}

	// No usable token: while backing off, report the last failure instead of
	// hitting the token endpoint again.
	if c.lastErr != nil && time.Now().Before(c.nextRetryTime) {
		err := c.lastErr
		c.mu.Unlock()
		return nil, err
	}

	if c.fetching == nil {
		c.startFetch()
	}
	// Capture the channel under the lock; the fetch goroutine resets
	// c.fetching to nil once it finishes.
	wait := c.fetching
	c.mu.Unlock()

	select {
	case <-wait:
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.isTokenValid() {
			return map[string]string{
				"authorization": "Bearer " + c.token.Value,
			}, nil
		}
		if c.lastErr != nil {
			return nil, c.lastErr
		}
		// The fetch succeeded but returned a token that is already expired.
		return nil, status.Error(codes.Unauthenticated, "fetched token is expired")
	case <-ctx.Done():
		// The fetch keeps running in the background for later callers.
		return nil, ctx.Err()
	}
}

// RequireTransportSecurity reports whether these credentials need a secure
// transport. It always returns true.
func (*gcpServiceAccountIdentityCallCreds) RequireTransportSecurity() bool {
	return true
}

// isTokenStale reports whether a refresh is due: there is no cached token, or
// the current time is already past the token's expiry minus earlyExpiry.
func (c *gcpServiceAccountIdentityCallCreds) isTokenStale() bool {
	tok := c.token
	if tok == nil {
		return true
	}

	// Round(0) strips any monotonic clock reading before the arithmetic.
	refreshAt := tok.Expiry.Round(0).Add(-earlyExpiry)
	return time.Now().After(refreshAt)
}

// isTokenValid reports whether a token is cached and its expiry is still in
// the future.
func (c *gcpServiceAccountIdentityCallCreds) isTokenValid() bool {
	return c.token != nil && time.Now().Before(c.token.Expiry)
}

// startFetch kicks off an asynchronous token fetch. It installs a fresh
// 'fetching' channel that concurrent callers can wait on, then spawns a
// goroutine that fetches a token and, under c.mu, closes the channel, clears
// 'fetching', updates the backoff state and stores the token on success.
//
// Callers in this file invoke it with c.mu held and only when c.fetching is
// nil.
func (c *gcpServiceAccountIdentityCallCreds) startFetch() {
	done := make(chan struct{})
	c.fetching = done

	go func() {
		tok, fetchErr := c.creds.TokenProvider.Token(context.Background())
		mapped := mapFetchError(fetchErr)

		c.mu.Lock()
		defer c.mu.Unlock()

		// Waiters are released here but cannot observe state until the lock
		// is dropped at the end of this function.
		close(done)
		c.fetching = nil
		c.setBackoff(mapped)
		if mapped == nil {
			c.token = tok
		}
	}()
}

// setBackoff records the outcome of a fetch attempt in the backoff state.
// A nil err clears the stored error, the failure counter and the retry time.
// A non-nil err is stored, and the next retry time is set to now plus the
// delay the backoff strategy returns for the current failure count, which is
// then incremented.
func (c *gcpServiceAccountIdentityCallCreds) setBackoff(err error) {
	c.lastErr = err
	if err != nil {
		delay := c.backoff.Backoff(c.retryAttempt)
		c.retryAttempt++
		c.nextRetryTime = time.Now().Add(delay)
		return
	}
	c.retryAttempt = 0
	c.nextRetryTime = time.Time{}
}

// mapFetchError converts an error from a token fetch into a gRPC status
// error. A nil err yields nil. Otherwise:
//   - a *metadata.Error with HTTP code 429, 502, 503 or 504 maps to
//     codes.Unavailable, and any other HTTP code to codes.Unauthenticated;
//   - a metadata.NotDefinedError maps to codes.Unauthenticated;
//   - every other error maps to codes.Unavailable.
//
// Both metadata error types are matched anywhere in err's chain, so wrapped
// errors are classified the same way as unwrapped ones. The status message is
// always err.Error().
func mapFetchError(err error) error {
	if err == nil {
		return nil
	}

	var metadataErr *metadata.Error
	if errors.As(err, &metadataErr) {
		switch metadataErr.Code {
		case http.StatusTooManyRequests, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
			return status.Error(codes.Unavailable, err.Error())
		default:
			return status.Error(codes.Unauthenticated, err.Error())
		}
	}

	// errors.As rather than a type assertion, so a wrapped NotDefinedError is
	// still recognised, consistent with the *metadata.Error check above.
	var notDefinedErr metadata.NotDefinedError
	if errors.As(err, &notDefinedErr) {
		return status.Error(codes.Unauthenticated, err.Error())
	}

	return status.Error(codes.Unavailable, err.Error())
}
Loading
Loading