This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 68420559 feat(catalog): refresh oauth tokens (#793)
68420559 is described below

commit 684205596d5c088c2d05823d4da70d49050d3f0e
Author: Tyler Rockwood <[email protected]>
AuthorDate: Wed Mar 18 14:24:43 2026 -0500

    feat(catalog): refresh oauth tokens (#793)
    
    The REST catalog was using a fixed token for the lifetime of the
    catalog. We need to refresh the token when the OAuth server gives us an
    expiration. This means the credential fetch needs to move into the
    RoundTripper. Also, since we use the same HTTP client for refreshing and
    making catalog requests, we add a context key to prevent recursion.
    
    Fixes: #794
---
 catalog/rest/auth.go               | 133 +++++++++----------------------------
 catalog/rest/auth_test.go          |  70 ++++++++++---------
 catalog/rest/rest.go               |  85 ++++++++++++++++++------
 catalog/rest/rest_internal_test.go | 133 +++++++++++++++++++++++++++++++------
 catalog/rest/rest_test.go          |  27 ++++----
 go.mod                             |   2 +-
 6 files changed, 263 insertions(+), 187 deletions(-)

diff --git a/catalog/rest/auth.go b/catalog/rest/auth.go
index 82dceb39..815ba4d1 100644
--- a/catalog/rest/auth.go
+++ b/catalog/rest/auth.go
@@ -18,12 +18,10 @@
 package rest
 
 import (
-       "encoding/json"
+       "errors"
        "fmt"
-       "io"
-       "net/http"
-       "net/url"
-       "strings"
+
+       "golang.org/x/oauth2"
 )
 
 // AuthManager is an interface for providing custom authorization headers.
@@ -32,115 +30,50 @@ type AuthManager interface {
        AuthHeader() (string, string, error)
 }
 
-type oauthTokenResponse struct {
-       AccessToken  string `json:"access_token"`
-       TokenType    string `json:"token_type"`
-       ExpiresIn    int    `json:"expires_in"`
-       Scope        string `json:"scope"`
-       RefreshToken string `json:"refresh_token"`
-}
-
-type oauthErrorResponse struct {
-       Err     string `json:"error"`
-       ErrDesc string `json:"error_description"`
-       ErrURI  string `json:"error_uri"`
-}
-
-func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
-func (o oauthErrorResponse) Error() string {
-       msg := o.Err
-       if o.ErrDesc != "" {
-               msg += ": " + o.ErrDesc
-       }
-
-       if o.ErrURI != "" {
-               msg += " (" + o.ErrURI + ")"
-       }
-
-       return msg
-}
-
 // Oauth2AuthManager is an implementation of the AuthManager interface which
-// simply returns the provided token as a bearer token. If a credential
-// is provided instead of a static token, it will fetch and refresh the
-// token as needed.
+// uses an oauth2.TokenSource to provide bearer tokens. The token source
+// handles caching, thread-safe refresh, and expiry management.
 type Oauth2AuthManager struct {
-       Token      string
-       Credential string
-
-       AuthURI *url.URL
-       Scope   string
-       Client  *http.Client
+       tokenSource oauth2.TokenSource
 }
 
 // AuthHeader returns the authorization header with the bearer token.
 func (o *Oauth2AuthManager) AuthHeader() (string, string, error) {
-       if o.Token == "" && o.Credential != "" {
-               if o.Client == nil {
-                       return "", "", fmt.Errorf("%w: cannot fetch token 
without http client", ErrRESTError)
+       tok, err := o.tokenSource.Token()
+       if err != nil {
+               var re *oauth2.RetrieveError
+               if errors.As(err, &re) {
+                       return "", "", oauthError{
+                               code: re.ErrorCode,
+                               desc: re.ErrorDescription,
+                               uri:  re.ErrorURI,
+                       }
                }
 
-               tok, err := o.fetchAccessToken()
-               if err != nil {
-                       return "", "", err
-               }
-               o.Token = tok
+               return "", "", fmt.Errorf("%w: %s", ErrOAuthError, err)
        }
 
-       return "Authorization", "Bearer " + o.Token, nil
+       return "Authorization", tok.Type() + " " + tok.AccessToken, nil
 }
 
-func (o *Oauth2AuthManager) fetchAccessToken() (string, error) {
-       clientID, clientSecret, hasID := strings.Cut(o.Credential, ":")
-       if !hasID {
-               clientID, clientSecret = "", o.Credential
-       }
-
-       scope := "catalog"
-       if o.Scope != "" {
-               scope = o.Scope
-       }
-       data := url.Values{
-               "grant_type":    {"client_credentials"},
-               "client_id":     {clientID},
-               "client_secret": {clientSecret},
-               "scope":         {scope},
-       }
+// oauthError wraps OAuth2 error details and implements the error chain
+// so that errors.Is(err, ErrOAuthError) returns true.
+type oauthError struct {
+       code string
+       desc string
+       uri  string
+}
 
-       if o.AuthURI == nil {
-               return "", fmt.Errorf("%w: missing auth uri for fetching 
token", ErrRESTError)
+func (e oauthError) Error() string {
+       msg := e.code
+       if e.desc != "" {
+               msg += ": " + e.desc
        }
-
-       rsp, err := o.Client.PostForm(o.AuthURI.String(), data)
-       if err != nil {
-               return "", err
+       if e.uri != "" {
+               msg += " (" + e.uri + ")"
        }
 
-       if rsp.StatusCode == http.StatusOK {
-               defer rsp.Body.Close()
-               dec := json.NewDecoder(rsp.Body)
-               var tok oauthTokenResponse
-               if err := dec.Decode(&tok); err != nil {
-                       return "", fmt.Errorf("failed to decode oauth token 
response: %w", err)
-               }
-
-               return tok.AccessToken, nil
-       }
-
-       switch rsp.StatusCode {
-       case http.StatusUnauthorized, http.StatusBadRequest:
-               defer func() {
-                       _, _ = io.Copy(io.Discard, rsp.Body)
-                       _ = rsp.Body.Close()
-               }()
-               dec := json.NewDecoder(rsp.Body)
-               var oauthErr oauthErrorResponse
-               if err := dec.Decode(&oauthErr); err != nil {
-                       return "", fmt.Errorf("failed to decode oauth error: 
%w", err)
-               }
-
-               return "", oauthErr
-       default:
-               return "", handleNon200(rsp, nil)
-       }
+       return msg
 }
+
+func (e oauthError) Unwrap() error { return ErrOAuthError }
diff --git a/catalog/rest/auth_test.go b/catalog/rest/auth_test.go
index 84dac1ed..f7bb0534 100644
--- a/catalog/rest/auth_test.go
+++ b/catalog/rest/auth_test.go
@@ -18,19 +18,25 @@
 package rest
 
 import (
+       "context"
        "encoding/json"
+       "errors"
        "net/http"
        "net/http/httptest"
-       "net/url"
        "testing"
 
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
+       "golang.org/x/oauth2"
+       "golang.org/x/oauth2/clientcredentials"
 )
 
 func TestOauth2AuthManager_AuthHeader_StaticToken(t *testing.T) {
        manager := &Oauth2AuthManager{
-               Token: "static_token",
+               tokenSource: oauth2.StaticTokenSource(&oauth2.Token{
+                       AccessToken: "static_token",
+                       TokenType:   "Bearer",
+               }),
        }
 
        key, value, err := manager.AuthHeader()
@@ -39,16 +45,6 @@ func TestOauth2AuthManager_AuthHeader_StaticToken(t 
*testing.T) {
        assert.Equal(t, "Bearer static_token", value)
 }
 
-func TestOauth2AuthManager_AuthHeader_MissingClient(t *testing.T) {
-       manager := &Oauth2AuthManager{
-               Credential: "client:secret",
-       }
-
-       _, _, err := manager.AuthHeader()
-       require.Error(t, err)
-       assert.Contains(t, err.Error(), "cannot fetch token without http 
client")
-}
-
 func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t *testing.T) {
        mux := http.NewServeMux()
        server := httptest.NewServer(mux)
@@ -61,28 +57,32 @@ func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t 
*testing.T) {
                assert.Equal(t, "secret", r.FormValue("client_secret"))
                assert.Equal(t, "catalog", r.FormValue("scope"))
 
+               w.Header().Set("Content-Type", "application/json")
                w.WriteHeader(http.StatusOK)
-               json.NewEncoder(w).Encode(oauthTokenResponse{
-                       AccessToken: "fetched_token",
-                       TokenType:   "Bearer",
-                       ExpiresIn:   3600,
+               json.NewEncoder(w).Encode(map[string]any{
+                       "access_token": "fetched_token",
+                       "token_type":   "Bearer",
+                       "expires_in":   3600,
                })
        })
 
-       authURL, err := url.Parse(server.URL + "/oauth/token")
-       require.NoError(t, err)
+       cfg := &clientcredentials.Config{
+               ClientID:     "client",
+               ClientSecret: "secret",
+               TokenURL:     server.URL + "/oauth/token",
+               Scopes:       []string{"catalog"},
+               AuthStyle:    oauth2.AuthStyleInParams,
+       }
 
+       ctx := context.WithValue(context.Background(), oauth2.HTTPClient, 
server.Client())
        manager := &Oauth2AuthManager{
-               Credential: "client:secret",
-               AuthURI:    authURL,
-               Client:     server.Client(),
+               tokenSource: cfg.TokenSource(ctx),
        }
 
        key, value, err := manager.AuthHeader()
        require.NoError(t, err)
        assert.Equal(t, "Authorization", key)
        assert.Equal(t, "Bearer fetched_token", value)
-       assert.Equal(t, "fetched_token", manager.Token)
 }
 
 func TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
@@ -91,23 +91,29 @@ func 
TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
        defer server.Close()
 
        mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r 
*http.Request) {
+               w.Header().Set("Content-Type", "application/json")
                w.WriteHeader(http.StatusBadRequest)
-               json.NewEncoder(w).Encode(oauthErrorResponse{
-                       Err:     "invalid_client",
-                       ErrDesc: "Invalid client credentials",
+               json.NewEncoder(w).Encode(map[string]any{
+                       "error":             "invalid_client",
+                       "error_description": "Invalid client credentials",
                })
        })
 
-       authURL, err := url.Parse(server.URL + "/oauth/token")
-       require.NoError(t, err)
+       cfg := &clientcredentials.Config{
+               ClientID:     "client",
+               ClientSecret: "secret",
+               TokenURL:     server.URL + "/oauth/token",
+               AuthStyle:    oauth2.AuthStyleInParams,
+       }
 
+       ctx := context.WithValue(context.Background(), oauth2.HTTPClient, 
server.Client())
        manager := &Oauth2AuthManager{
-               Credential: "client:secret",
-               AuthURI:    authURL,
-               Client:     server.Client(),
+               tokenSource: cfg.TokenSource(ctx),
        }
 
-       _, _, err = manager.AuthHeader()
+       _, _, err := manager.AuthHeader()
        require.Error(t, err)
-       assert.Contains(t, err.Error(), "invalid_client: Invalid client 
credentials")
+       assert.True(t, errors.Is(err, ErrOAuthError), "error should wrap 
ErrOAuthError")
+       assert.Contains(t, err.Error(), "invalid_client")
+       assert.Contains(t, err.Error(), "Invalid client credentials")
 }
diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go
index 943b98da..41ead570 100644
--- a/catalog/rest/rest.go
+++ b/catalog/rest/rest.go
@@ -44,12 +44,15 @@ import (
        "github.com/aws/aws-sdk-go-v2/aws"
        v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
        "github.com/aws/aws-sdk-go-v2/config"
+       "golang.org/x/oauth2"
+       "golang.org/x/oauth2/clientcredentials"
 )
 
 var _ catalog.Catalog = (*Catalog)(nil)
 
 const (
        pageSizeKey contextKey = "page_size"
+       skipOAuth   contextKey = "skip_oauth"
 
        defaultPageSize = 20
 
@@ -169,6 +172,7 @@ type configResponse struct {
 type sessionTransport struct {
        http.RoundTripper
 
+       authManager    AuthManager
        defaultHeaders http.Header
        signer         v4.HTTPSigner
        cfg            aws.Config
@@ -191,6 +195,14 @@ func (s *sessionTransport) RoundTrip(r *http.Request) 
(*http.Response, error) {
                }
        }
 
+       if s.authManager != nil && r.Context().Value(skipOAuth) == nil {
+               k, v, err := s.authManager.AuthHeader()
+               if err != nil {
+                       return nil, err
+               }
+               r.Header.Set(k, v)
+       }
+
        if s.signer != nil {
                var payloadHash string
                if r.Body == nil {
@@ -490,19 +502,53 @@ func NewCatalog(ctx context.Context, name, uri string, 
opts ...Option) (*Catalog
 }
 
 // setupOAuthManager creates an Oauth2AuthManager based on the provided 
options.
-// The allows users to set their token, credential, or just get the defaults 
if no auth manager is set.
-func setupOAuthManager(r *Catalog, cl *http.Client, opts *options) 
*Oauth2AuthManager {
-       authURI := opts.authUri
-       if authURI == nil {
-               authURI = r.baseURI.JoinPath("oauth/tokens")
+// It uses golang.org/x/oauth2 for token management, caching, and thread-safe 
refresh.
+func setupOAuthManager(r *Catalog, cl *http.Client, opts *options) AuthManager 
{
+       // If a static token is provided, use it directly.
+       if opts.oauthToken != "" {
+               return &Oauth2AuthManager{
+                       tokenSource: oauth2.StaticTokenSource(&oauth2.Token{
+                               AccessToken: opts.oauthToken,
+                               TokenType:   "Bearer",
+                       }),
+               }
+       }
+
+       // If no credential, no auth needed.
+       if opts.credential == "" {
+               return nil
+       }
+
+       authURL := opts.authUri
+       if authURL == nil {
+               authURL = r.baseURI.JoinPath("oauth/tokens")
+       }
+
+       clientID, clientSecret, found := strings.Cut(opts.credential, ":")
+       if !found {
+               clientID = ""
+               clientSecret = opts.credential
+       }
+
+       cfg := &clientcredentials.Config{
+               ClientID:     clientID,
+               ClientSecret: clientSecret,
+               TokenURL:     authURL.String(),
+               AuthStyle:    oauth2.AuthStyleInParams,
        }
 
+       scope := "catalog"
+       if opts.scope != "" {
+               scope = opts.scope
+       }
+       cfg.Scopes = []string{scope}
+
+       // Add skip oauth so we don't get in cycles trying to refresh the token
+       ctx := context.WithValue(context.Background(), skipOAuth, true)
+       ctx = context.WithValue(ctx, oauth2.HTTPClient, cl)
+
        return &Oauth2AuthManager{
-               Token:      opts.oauthToken,
-               Credential: opts.credential,
-               AuthURI:    authURI,
-               Scope:      opts.scope,
-               Client:     cl,
+               tokenSource: cfg.TokenSource(ctx),
        }
 }
 
@@ -526,13 +572,16 @@ func (r *Catalog) init(ctx context.Context, ops *options, 
uri string) error {
 }
 
 func (r *Catalog) createSession(ctx context.Context, opts *options) 
(*http.Client, error) {
-       session := &sessionTransport{
-               defaultHeaders: http.Header{},
-       }
+       var baseTransport http.RoundTripper
        if opts.transport != nil {
-               session.RoundTripper = opts.transport
+               baseTransport = opts.transport
        } else {
-               session.RoundTripper = &http.Transport{Proxy: 
http.ProxyFromEnvironment, TLSClientConfig: opts.tlsConfig}
+               baseTransport = &http.Transport{Proxy: 
http.ProxyFromEnvironment, TLSClientConfig: opts.tlsConfig}
+       }
+
+       session := &sessionTransport{
+               RoundTripper:   baseTransport,
+               defaultHeaders: http.Header{},
        }
        cl := &http.Client{Transport: session}
 
@@ -551,11 +600,7 @@ func (r *Catalog) createSession(ctx context.Context, opts 
*options) (*http.Clien
        }
 
        if opts.authManager != nil {
-               k, v, err := opts.authManager.AuthHeader()
-               if err != nil {
-                       return nil, err
-               }
-               session.defaultHeaders.Set(k, v)
+               session.authManager = opts.authManager
        }
 
        if opts.enableSigv4 {
diff --git a/catalog/rest/rest_internal_test.go 
b/catalog/rest/rest_internal_test.go
index 6141ce53..bad6060b 100644
--- a/catalog/rest/rest_internal_test.go
+++ b/catalog/rest/rest_internal_test.go
@@ -26,6 +26,7 @@ import (
        "crypto/x509"
        "encoding/hex"
        "encoding/json"
+       "fmt"
        "io"
        "net/http"
        "net/http/httptest"
@@ -61,7 +62,7 @@ func TestTokenAuthenticationPriority(t *testing.T) {
 
        mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, r 
*http.Request) {
                oauthCalled = true
-               w.WriteHeader(http.StatusOK)
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token": "oauth_token_response",
                        "token_type":   "Bearer",
@@ -136,12 +137,11 @@ func TestScope(t *testing.T) {
 
                require.NoError(t, req.ParseForm())
                values := req.PostForm
-               assert.Equal(t, values.Get("grant_type"), "client_credentials")
-               assert.Equal(t, values.Get("client_secret"), "secret")
-               assert.Equal(t, values.Get("scope"), "my_scope")
-
-               w.WriteHeader(http.StatusOK)
+               assert.Equal(t, "client_credentials", values.Get("grant_type"))
+               assert.Equal(t, "secret", values.Get("client_secret"))
+               assert.Equal(t, "my_scope", values.Get("scope"))
 
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      "some_jwt_token",
                        "token_type":        "Bearer",
@@ -179,13 +179,12 @@ func TestAuthHeader(t *testing.T) {
 
                require.NoError(t, req.ParseForm())
                values := req.PostForm
-               assert.Equal(t, values.Get("grant_type"), "client_credentials")
-               assert.Equal(t, values.Get("client_id"), "client")
-               assert.Equal(t, values.Get("client_secret"), "secret")
-               assert.Equal(t, values.Get("scope"), "catalog")
-
-               w.WriteHeader(http.StatusOK)
+               assert.Equal(t, "client_credentials", values.Get("grant_type"))
+               assert.Equal(t, "client", values.Get("client_id"))
+               assert.Equal(t, "secret", values.Get("client_secret"))
+               assert.Equal(t, "catalog", values.Get("scope"))
 
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      "some_jwt_token",
                        "token_type":        "Bearer",
@@ -194,19 +193,30 @@ func TestAuthHeader(t *testing.T) {
                })
        })
 
+       var capturedAuthHeader string
+       mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r 
*http.Request) {
+               capturedAuthHeader = r.Header.Get("Authorization")
+               json.NewEncoder(w).Encode(map[string]any{"namespaces": 
[][]string{}})
+       })
+
        cat, err := NewCatalog(context.Background(), "rest", srv.URL,
                WithCredential("client:secret"))
        require.NoError(t, err)
        assert.NotNil(t, cat)
 
+       // Verify default headers (excluding Authorization, which is now set 
per-request).
        require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport)
        assert.Equal(t, http.Header{
-               "Authorization":               {"Bearer some_jwt_token"},
                "Content-Type":                {"application/json"},
                "User-Agent":                  {"GoIceberg/(unknown version)"},
                "X-Client-Version":            {icebergRestSpecVersion},
                "X-Iceberg-Access-Delegation": {"vended-credentials"},
        }, cat.cl.Transport.(*sessionTransport).defaultHeaders)
+
+       // Verify Authorization is set on actual requests.
+       _, err = cat.ListNamespaces(context.Background(), nil)
+       require.NoError(t, err)
+       assert.Equal(t, "Bearer some_jwt_token", capturedAuthHeader)
 }
 
 func TestAuthUriHeader(t *testing.T) {
@@ -227,13 +237,12 @@ func TestAuthUriHeader(t *testing.T) {
 
                require.NoError(t, req.ParseForm())
                values := req.PostForm
-               assert.Equal(t, values.Get("grant_type"), "client_credentials")
-               assert.Equal(t, values.Get("client_id"), "client")
-               assert.Equal(t, values.Get("client_secret"), "secret")
-               assert.Equal(t, values.Get("scope"), "catalog")
-
-               w.WriteHeader(http.StatusOK)
+               assert.Equal(t, "client_credentials", values.Get("grant_type"))
+               assert.Equal(t, "client", values.Get("client_id"))
+               assert.Equal(t, "secret", values.Get("client_secret"))
+               assert.Equal(t, "catalog", values.Get("scope"))
 
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      "some_jwt_token",
                        "token_type":        "Bearer",
@@ -242,6 +251,12 @@ func TestAuthUriHeader(t *testing.T) {
                })
        })
 
+       var capturedAuthHeader string
+       mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r 
*http.Request) {
+               capturedAuthHeader = r.Header.Get("Authorization")
+               json.NewEncoder(w).Encode(map[string]any{"namespaces": 
[][]string{}})
+       })
+
        authUri, err := url.Parse(srv.URL)
        require.NoError(t, err)
        cat, err := NewCatalog(context.Background(), "rest", srv.URL,
@@ -251,12 +266,15 @@ func TestAuthUriHeader(t *testing.T) {
 
        require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport)
        assert.Equal(t, http.Header{
-               "Authorization":               {"Bearer some_jwt_token"},
                "Content-Type":                {"application/json"},
                "User-Agent":                  {"GoIceberg/(unknown version)"},
                "X-Client-Version":            {icebergRestSpecVersion},
                "X-Iceberg-Access-Delegation": {"vended-credentials"},
        }, cat.cl.Transport.(*sessionTransport).defaultHeaders)
+
+       _, err = cat.ListNamespaces(context.Background(), nil)
+       require.NoError(t, err)
+       assert.Equal(t, "Bearer some_jwt_token", capturedAuthHeader)
 }
 
 func TestSigv4EmptyStringHash(t *testing.T) {
@@ -466,6 +484,81 @@ func TestSigv4ConcurrentSigners(t *testing.T) {
        t.Logf("issued %d requests", count.Load())
 }
 
+func TestCredentialRefreshOnExpiry(t *testing.T) {
+       t.Parallel()
+
+       var tokenVersion atomic.Int64
+       var oauthCallCount atomic.Int64
+
+       mux := http.NewServeMux()
+       srv := httptest.NewServer(mux)
+       defer srv.Close()
+
+       mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r 
*http.Request) {
+               json.NewEncoder(w).Encode(map[string]any{
+                       "defaults": map[string]any{}, "overrides": 
map[string]any{},
+               })
+       })
+
+       mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, r 
*http.Request) {
+               n := oauthCallCount.Add(1)
+               tokenVersion.Store(n)
+               w.Header().Set("Content-Type", "application/json")
+               json.NewEncoder(w).Encode(map[string]any{
+                       "access_token": fmt.Sprintf("token_v%d", n),
+                       "token_type":   "Bearer",
+                       "expires_in":   1, // expires in 1 second
+               })
+       })
+
+       mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r 
*http.Request) {
+               auth := r.Header.Get("Authorization")
+               currentVersion := tokenVersion.Load()
+               expectedToken := fmt.Sprintf("Bearer token_v%d", currentVersion)
+
+               if auth != expectedToken {
+                       // Simulate server rejecting an expired/stale token.
+                       w.WriteHeader(http.StatusUnauthorized)
+                       json.NewEncoder(w).Encode(map[string]any{
+                               "error": map[string]any{
+                                       "message": "Token expired",
+                                       "type":    "NotAuthorizedException",
+                                       "code":    401,
+                               },
+                       })
+
+                       return
+               }
+               json.NewEncoder(w).Encode(map[string]any{
+                       "namespaces": [][]string{{"ns1"}},
+               })
+       })
+
+       cat, err := NewCatalog(context.Background(), "rest", srv.URL,
+               WithCredential("client:secret"))
+       require.NoError(t, err)
+
+       // First call should succeed - the token was just fetched during 
session creation.
+       namespaces, err := cat.ListNamespaces(context.Background(), nil)
+       require.NoError(t, err)
+       assert.Len(t, namespaces, 1)
+
+       // Wait for the token to "expire" and bump the server's expected version
+       // so the old token is rejected.
+       time.Sleep(2 * time.Second)
+       tokenVersion.Add(1)
+
+       // The catalog should automatically refresh its credential and retry,
+       // so this call should succeed transparently.
+       namespaces, err = cat.ListNamespaces(context.Background(), nil)
+       require.NoError(t, err, "catalog should refresh expired credentials 
automatically")
+       assert.Len(t, namespaces, 1)
+
+       // The OAuth endpoint should have been called a second time to get a 
fresh token.
+       assert.GreaterOrEqual(t, oauthCallCount.Load(), int64(2),
+               "OAuth endpoint should be called again to refresh the expired 
token")
+}
+
 // trackingReadCloser wraps an io.ReadCloser to track if Close() was called
 type trackingReadCloser struct {
        io.ReadCloser
diff --git a/catalog/rest/rest_test.go b/catalog/rest/rest_test.go
index 046c415a..ae8e16aa 100644
--- a/catalog/rest/rest_test.go
+++ b/catalog/rest/rest_test.go
@@ -101,8 +101,7 @@ func (r *RestCatalogSuite) TestToken200() {
                r.Equal(values.Get("client_secret"), "secret")
                r.Equal(values.Get("scope"), scope)
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -135,8 +134,7 @@ func (r *RestCatalogSuite) TestLoadRegisteredCatalog() {
                r.Equal(values.Get("client_secret"), "secret")
                r.Equal(values.Get("scope"), "catalog")
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -162,6 +160,7 @@ func (r *RestCatalogSuite) TestToken400() {
 
                r.Equal(req.Header.Get("Content-Type"), 
"application/x-www-form-urlencoded")
 
+               w.Header().Set("Content-Type", "application/json")
                w.WriteHeader(http.StatusBadRequest)
 
                json.NewEncoder(w).Encode(map[string]any{
@@ -191,8 +190,7 @@ func (r *RestCatalogSuite) TestToken200AuthUrl() {
                r.Equal(values.Get("client_secret"), "secret")
                r.Equal(values.Get("scope"), "catalog")
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -219,6 +217,7 @@ func (r *RestCatalogSuite) TestToken401() {
 
                r.Equal(req.Header.Get("Content-Type"), 
"application/x-www-form-urlencoded")
 
+               w.Header().Set("Content-Type", "application/json")
                w.WriteHeader(http.StatusUnauthorized)
 
                json.NewEncoder(w).Encode(map[string]any{
@@ -242,8 +241,7 @@ func (r *RestCatalogSuite) TestTokenContentTypeDuplicated() 
{
                values := req.Header.Values("Content-Type")
                r.Equal([]string{"application/x-www-form-urlencoded"}, values)
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -307,8 +305,7 @@ func (r *RestCatalogSuite) TestWithHeadersOnOAuthRoute() {
                        r.Equal(v, req.Header.Get(k))
                }
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -339,8 +336,7 @@ func (r *RestCatalogSuite) TestWithHeadersOnAuthURLRoute() {
                        r.Equal(v, req.Header.Get(k))
                }
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -418,8 +414,7 @@ func (r *RestCatalogSuite) TestListTablesPrefixed200() {
                r.Equal(values.Get("client_secret"), "secret")
                r.Equal(values.Get("scope"), "catalog")
 
-               w.WriteHeader(http.StatusOK)
-
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token":      TestToken,
                        "token_type":        "Bearer",
@@ -2752,6 +2747,7 @@ func (r *RestCatalogSuite) TestCreateTableStaged() {
        var lastCommitBody map[string]any
 
        r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req 
*http.Request) {
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token": TestToken, "token_type": "Bearer", 
"expires_in": 3600,
                })
@@ -2876,6 +2872,7 @@ func (r *RestCatalogSuite) TestCreateTableNotStaged() {
        var commitCalled bool
 
        r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req 
*http.Request) {
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token": TestToken, "token_type": "Bearer", 
"expires_in": 3600,
                })
@@ -2918,6 +2915,7 @@ func (r *RestCatalogSuite) 
TestCommitTableErrCommitStateUnknown() {
        var statusCode int
 
        r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req 
*http.Request) {
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token": TestToken, "token_type": "Bearer", 
"expires_in": 3600,
                })
@@ -2962,6 +2960,7 @@ func (r *RestCatalogSuite) 
TestUpdateTableErrCommitStateUnknown() {
        var statusCode int
 
        r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req 
*http.Request) {
+               w.Header().Set("Content-Type", "application/json")
                json.NewEncoder(w).Encode(map[string]any{
                        "access_token": TestToken, "token_type": "Bearer", 
"expires_in": 3600,
                })
diff --git a/go.mod b/go.mod
index 490f9258..cfc2fa4b 100644
--- a/go.mod
+++ b/go.mod
@@ -52,6 +52,7 @@ require (
        github.com/uptrace/bun/driver/sqliteshim v1.2.18
        github.com/uptrace/bun/extra/bundebug v1.2.18
        gocloud.dev v0.45.0
+       golang.org/x/oauth2 v0.36.0
        golang.org/x/sync v0.20.0
        google.golang.org/api v0.271.0
        gopkg.in/yaml.v3 v3.0.1
@@ -264,7 +265,6 @@ require (
        golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect
        golang.org/x/mod v0.33.0 // indirect
        golang.org/x/net v0.51.0 // indirect
-       golang.org/x/oauth2 v0.36.0 // indirect
        golang.org/x/sys v0.42.0 // indirect
        golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect
        golang.org/x/term v0.40.0 // indirect

Reply via email to