This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new 68420559 feat(catalog): refresh oauth tokens (#793)
68420559 is described below
commit 684205596d5c088c2d05823d4da70d49050d3f0e
Author: Tyler Rockwood <[email protected]>
AuthorDate: Wed Mar 18 14:24:43 2026 -0500
feat(catalog): refresh oauth tokens (#793)
The REST catalog was using a fixed token for the lifetime of the
catalog. We need to refresh the token when the OAuth server gives us an
expiration. This means the credential fetch needs to move into the
round tripper. Also, since we use the same HTTP client for refreshing and
making catalog requests, we add a context key to prevent recursion.
Fixes: #794
---
catalog/rest/auth.go | 133 +++++++++----------------------------
catalog/rest/auth_test.go | 70 ++++++++++---------
catalog/rest/rest.go | 85 ++++++++++++++++++------
catalog/rest/rest_internal_test.go | 133 +++++++++++++++++++++++++++++++------
catalog/rest/rest_test.go | 27 ++++----
go.mod | 2 +-
6 files changed, 263 insertions(+), 187 deletions(-)
diff --git a/catalog/rest/auth.go b/catalog/rest/auth.go
index 82dceb39..815ba4d1 100644
--- a/catalog/rest/auth.go
+++ b/catalog/rest/auth.go
@@ -18,12 +18,10 @@
package rest
import (
- "encoding/json"
+ "errors"
"fmt"
- "io"
- "net/http"
- "net/url"
- "strings"
+
+ "golang.org/x/oauth2"
)
// AuthManager is an interface for providing custom authorization headers.
@@ -32,115 +30,50 @@ type AuthManager interface {
AuthHeader() (string, string, error)
}
-type oauthTokenResponse struct {
- AccessToken string `json:"access_token"`
- TokenType string `json:"token_type"`
- ExpiresIn int `json:"expires_in"`
- Scope string `json:"scope"`
- RefreshToken string `json:"refresh_token"`
-}
-
-type oauthErrorResponse struct {
- Err string `json:"error"`
- ErrDesc string `json:"error_description"`
- ErrURI string `json:"error_uri"`
-}
-
-func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
-func (o oauthErrorResponse) Error() string {
- msg := o.Err
- if o.ErrDesc != "" {
- msg += ": " + o.ErrDesc
- }
-
- if o.ErrURI != "" {
- msg += " (" + o.ErrURI + ")"
- }
-
- return msg
-}
-
// Oauth2AuthManager is an implementation of the AuthManager interface which
-// simply returns the provided token as a bearer token. If a credential
-// is provided instead of a static token, it will fetch and refresh the
-// token as needed.
+// uses an oauth2.TokenSource to provide bearer tokens. The token source
+// handles caching, thread-safe refresh, and expiry management.
+//
+// The zero value has no token source and cannot produce a header; its
+// AuthHeader returns an error instead of panicking. Code outside this
+// package should construct one with NewOauth2AuthManager.
type Oauth2AuthManager struct {
- Token string
- Credential string
-
- AuthURI *url.URL
- Scope string
- Client *http.Client
+	tokenSource oauth2.TokenSource
}
+
+// NewOauth2AuthManager returns an Oauth2AuthManager that obtains its tokens
+// from ts. Since AuthHeader calls ts.Token() on every request, ts should
+// cache tokens (for example oauth2.ReuseTokenSource or the source returned
+// by clientcredentials.Config.TokenSource).
+func NewOauth2AuthManager(ts oauth2.TokenSource) *Oauth2AuthManager {
+	return &Oauth2AuthManager{tokenSource: ts}
+}
// AuthHeader returns the authorization header with the bearer token.
+//
+// A token is requested from the token source on every call, so a refreshing
+// source can replace an expired token transparently. Token endpoint errors
+// are returned as errors that satisfy errors.Is(err, ErrOAuthError).
func (o *Oauth2AuthManager) AuthHeader() (string, string, error) {
- if o.Token == "" && o.Credential != "" {
- if o.Client == nil {
- return "", "", fmt.Errorf("%w: cannot fetch token
without http client", ErrRESTError)
+	// Guard the zero value: without this, Token() below is a nil dereference.
+	if o.tokenSource == nil {
+		return "", "", fmt.Errorf("%w: oauth2 auth manager has no token source", ErrRESTError)
+	}
+
+	tok, err := o.tokenSource.Token()
+	if err != nil {
+		// Structured OAuth2 error responses from the token endpoint.
+		var re *oauth2.RetrieveError
+		if errors.As(err, &re) {
+			return "", "", oauthError{
+				code: re.ErrorCode,
+				desc: re.ErrorDescription,
+				uri:  re.ErrorURI,
+			}
+		}
- tok, err := o.fetchAccessToken()
- if err != nil {
- return "", "", err
- }
- o.Token = tok
+		// Wrap both so callers can match ErrOAuthError AND inspect the
+		// underlying cause (transport errors, context errors, ...).
+		return "", "", fmt.Errorf("%w: %w", ErrOAuthError, err)
}
- return "Authorization", "Bearer " + o.Token, nil
+	return "Authorization", tok.Type() + " " + tok.AccessToken, nil
}
-func (o *Oauth2AuthManager) fetchAccessToken() (string, error) {
- clientID, clientSecret, hasID := strings.Cut(o.Credential, ":")
- if !hasID {
- clientID, clientSecret = "", o.Credential
- }
-
- scope := "catalog"
- if o.Scope != "" {
- scope = o.Scope
- }
- data := url.Values{
- "grant_type": {"client_credentials"},
- "client_id": {clientID},
- "client_secret": {clientSecret},
- "scope": {scope},
- }
+// oauthError wraps OAuth2 error details and implements the error chain
+// so that errors.Is(err, ErrOAuthError) returns true.
+// AuthHeader builds it from the fields of an *oauth2.RetrieveError.
+type oauthError struct {
+ code string // "error" code from the token endpoint, e.g. "invalid_client"
+ desc string // optional "error_description"; may be empty
+ uri string // optional "error_uri"; may be empty
+}
- if o.AuthURI == nil {
- return "", fmt.Errorf("%w: missing auth uri for fetching
token", ErrRESTError)
+// Error renders the error as "code: description (uri)", leaving out the
+// description and URI parts when they are empty.
+func (e oauthError) Error() string {
+ msg := e.code
+ if e.desc != "" {
+ msg += ": " + e.desc
}
-
- rsp, err := o.Client.PostForm(o.AuthURI.String(), data)
- if err != nil {
- return "", err
+ if e.uri != "" {
+ msg += " (" + e.uri + ")"
}
- if rsp.StatusCode == http.StatusOK {
- defer rsp.Body.Close()
- dec := json.NewDecoder(rsp.Body)
- var tok oauthTokenResponse
- if err := dec.Decode(&tok); err != nil {
- return "", fmt.Errorf("failed to decode oauth token
response: %w", err)
- }
-
- return tok.AccessToken, nil
- }
-
- switch rsp.StatusCode {
- case http.StatusUnauthorized, http.StatusBadRequest:
- defer func() {
- _, _ = io.Copy(io.Discard, rsp.Body)
- _ = rsp.Body.Close()
- }()
- dec := json.NewDecoder(rsp.Body)
- var oauthErr oauthErrorResponse
- if err := dec.Decode(&oauthErr); err != nil {
- return "", fmt.Errorf("failed to decode oauth error:
%w", err)
- }
-
- return "", oauthErr
- default:
- return "", handleNon200(rsp, nil)
- }
+ return msg
}
+
+// Unwrap makes every oauthError match the ErrOAuthError sentinel.
+func (e oauthError) Unwrap() error { return ErrOAuthError }
diff --git a/catalog/rest/auth_test.go b/catalog/rest/auth_test.go
index 84dac1ed..f7bb0534 100644
--- a/catalog/rest/auth_test.go
+++ b/catalog/rest/auth_test.go
@@ -18,19 +18,25 @@
package rest
import (
+ "context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
- "net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
)
func TestOauth2AuthManager_AuthHeader_StaticToken(t *testing.T) {
manager := &Oauth2AuthManager{
- Token: "static_token",
+ tokenSource: oauth2.StaticTokenSource(&oauth2.Token{
+ AccessToken: "static_token",
+ TokenType: "Bearer",
+ }),
}
key, value, err := manager.AuthHeader()
@@ -39,16 +45,6 @@ func TestOauth2AuthManager_AuthHeader_StaticToken(t
*testing.T) {
assert.Equal(t, "Bearer static_token", value)
}
-func TestOauth2AuthManager_AuthHeader_MissingClient(t *testing.T) {
- manager := &Oauth2AuthManager{
- Credential: "client:secret",
- }
-
- _, _, err := manager.AuthHeader()
- require.Error(t, err)
- assert.Contains(t, err.Error(), "cannot fetch token without http
client")
-}
-
func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
@@ -61,28 +57,32 @@ func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t
*testing.T) {
assert.Equal(t, "secret", r.FormValue("client_secret"))
assert.Equal(t, "catalog", r.FormValue("scope"))
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
- json.NewEncoder(w).Encode(oauthTokenResponse{
- AccessToken: "fetched_token",
- TokenType: "Bearer",
- ExpiresIn: 3600,
+ json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "fetched_token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
})
})
- authURL, err := url.Parse(server.URL + "/oauth/token")
- require.NoError(t, err)
+ cfg := &clientcredentials.Config{
+ ClientID: "client",
+ ClientSecret: "secret",
+ TokenURL: server.URL + "/oauth/token",
+ Scopes: []string{"catalog"},
+ AuthStyle: oauth2.AuthStyleInParams,
+ }
+ ctx := context.WithValue(context.Background(), oauth2.HTTPClient,
server.Client())
manager := &Oauth2AuthManager{
- Credential: "client:secret",
- AuthURI: authURL,
- Client: server.Client(),
+ tokenSource: cfg.TokenSource(ctx),
}
key, value, err := manager.AuthHeader()
require.NoError(t, err)
assert.Equal(t, "Authorization", key)
assert.Equal(t, "Bearer fetched_token", value)
- assert.Equal(t, "fetched_token", manager.Token)
}
func TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
@@ -91,23 +91,29 @@ func
TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
defer server.Close()
mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r
*http.Request) {
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
- json.NewEncoder(w).Encode(oauthErrorResponse{
- Err: "invalid_client",
- ErrDesc: "Invalid client credentials",
+ json.NewEncoder(w).Encode(map[string]any{
+ "error": "invalid_client",
+ "error_description": "Invalid client credentials",
})
})
- authURL, err := url.Parse(server.URL + "/oauth/token")
- require.NoError(t, err)
+ cfg := &clientcredentials.Config{
+ ClientID: "client",
+ ClientSecret: "secret",
+ TokenURL: server.URL + "/oauth/token",
+ AuthStyle: oauth2.AuthStyleInParams,
+ }
+ ctx := context.WithValue(context.Background(), oauth2.HTTPClient,
server.Client())
manager := &Oauth2AuthManager{
- Credential: "client:secret",
- AuthURI: authURL,
- Client: server.Client(),
+ tokenSource: cfg.TokenSource(ctx),
}
- _, _, err = manager.AuthHeader()
+ _, _, err := manager.AuthHeader()
require.Error(t, err)
- assert.Contains(t, err.Error(), "invalid_client: Invalid client
credentials")
+ assert.True(t, errors.Is(err, ErrOAuthError), "error should wrap
ErrOAuthError")
+ assert.Contains(t, err.Error(), "invalid_client")
+ assert.Contains(t, err.Error(), "Invalid client credentials")
}
diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go
index 943b98da..41ead570 100644
--- a/catalog/rest/rest.go
+++ b/catalog/rest/rest.go
@@ -44,12 +44,15 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
)
var _ catalog.Catalog = (*Catalog)(nil)
const (
pageSizeKey contextKey = "page_size"
+ skipOAuth contextKey = "skip_oauth"
defaultPageSize = 20
@@ -169,6 +172,7 @@ type configResponse struct {
type sessionTransport struct {
http.RoundTripper
+ authManager AuthManager
defaultHeaders http.Header
signer v4.HTTPSigner
cfg aws.Config
@@ -191,6 +195,14 @@ func (s *sessionTransport) RoundTrip(r *http.Request)
(*http.Response, error) {
}
}
+ if s.authManager != nil && r.Context().Value(skipOAuth) == nil {
+ k, v, err := s.authManager.AuthHeader()
+ if err != nil {
+ return nil, err
+ }
+ r.Header.Set(k, v)
+ }
+
if s.signer != nil {
var payloadHash string
if r.Body == nil {
@@ -490,19 +502,53 @@ func NewCatalog(ctx context.Context, name, uri string,
opts ...Option) (*Catalog
}
// setupOAuthManager creates an Oauth2AuthManager based on the provided
options.
-// The allows users to set their token, credential, or just get the defaults
if no auth manager is set.
-func setupOAuthManager(r *Catalog, cl *http.Client, opts *options)
*Oauth2AuthManager {
- authURI := opts.authUri
- if authURI == nil {
- authURI = r.baseURI.JoinPath("oauth/tokens")
+// It uses golang.org/x/oauth2 for token management, caching, and thread-safe
+// refresh.
+//
+// A static token (opts.oauthToken) takes precedence over a client
+// credential. When neither is configured it returns a nil AuthManager.
+// Token requests go to opts.authUri, or to <base URI>/oauth/tokens when no
+// auth URI was given.
+func setupOAuthManager(r *Catalog, cl *http.Client, opts *options) AuthManager
{
+ // If a static token is provided, use it directly.
+ if opts.oauthToken != "" {
+ return &Oauth2AuthManager{
+ tokenSource: oauth2.StaticTokenSource(&oauth2.Token{
+ AccessToken: opts.oauthToken,
+ TokenType: "Bearer",
+ }),
+ }
+ }
+
+ // No token and no credential: nothing to authenticate with.
+ if opts.credential == "" {
+ return nil
+ }
+
+ authURL := opts.authUri
+ if authURL == nil {
+ authURL = r.baseURI.JoinPath("oauth/tokens")
+ }
+
+ // The credential is "client_id:client_secret"; without a colon the whole
+ // value is used as the client secret and the client ID is left empty.
+ clientID, clientSecret, found := strings.Cut(opts.credential, ":")
+ if !found {
+ clientID = ""
+ clientSecret = opts.credential
+ }
+
+ // AuthStyleInParams sends client_id/client_secret in the form body rather
+ // than via HTTP Basic auth.
+ cfg := &clientcredentials.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ TokenURL: authURL.String(),
+ AuthStyle: oauth2.AuthStyleInParams,
}
+ scope := "catalog"
+ if opts.scope != "" {
+ scope = opts.scope
+ }
+ cfg.Scopes = []string{scope}
+
+ // Token requests are sent through cl, whose sessionTransport would
+ // otherwise try to attach (and therefore fetch) a token for the token
+ // request itself; the skipOAuth marker breaks that cycle.
+ // NOTE: the token source keeps this background context for every refresh,
+ // so a refresh is not cancelled by the context of the request that
+ // triggered it.
+ ctx := context.WithValue(context.Background(), skipOAuth, true)
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, cl)
+
return &Oauth2AuthManager{
- Token: opts.oauthToken,
- Credential: opts.credential,
- AuthURI: authURI,
- Scope: opts.scope,
- Client: cl,
+ tokenSource: cfg.TokenSource(ctx),
}
}
@@ -526,13 +572,16 @@ func (r *Catalog) init(ctx context.Context, ops *options,
uri string) error {
}
func (r *Catalog) createSession(ctx context.Context, opts *options)
(*http.Client, error) {
- session := &sessionTransport{
- defaultHeaders: http.Header{},
- }
+ var baseTransport http.RoundTripper
if opts.transport != nil {
- session.RoundTripper = opts.transport
+ baseTransport = opts.transport
} else {
- session.RoundTripper = &http.Transport{Proxy:
http.ProxyFromEnvironment, TLSClientConfig: opts.tlsConfig}
+ baseTransport = &http.Transport{Proxy:
http.ProxyFromEnvironment, TLSClientConfig: opts.tlsConfig}
+ }
+
+ session := &sessionTransport{
+ RoundTripper: baseTransport,
+ defaultHeaders: http.Header{},
}
cl := &http.Client{Transport: session}
@@ -551,11 +600,7 @@ func (r *Catalog) createSession(ctx context.Context, opts
*options) (*http.Clien
}
if opts.authManager != nil {
- k, v, err := opts.authManager.AuthHeader()
- if err != nil {
- return nil, err
- }
- session.defaultHeaders.Set(k, v)
+ session.authManager = opts.authManager
}
if opts.enableSigv4 {
diff --git a/catalog/rest/rest_internal_test.go
b/catalog/rest/rest_internal_test.go
index 6141ce53..bad6060b 100644
--- a/catalog/rest/rest_internal_test.go
+++ b/catalog/rest/rest_internal_test.go
@@ -26,6 +26,7 @@ import (
"crypto/x509"
"encoding/hex"
"encoding/json"
+ "fmt"
"io"
"net/http"
"net/http/httptest"
@@ -61,7 +62,7 @@ func TestTokenAuthenticationPriority(t *testing.T) {
mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, r
*http.Request) {
oauthCalled = true
- w.WriteHeader(http.StatusOK)
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "oauth_token_response",
"token_type": "Bearer",
@@ -136,12 +137,11 @@ func TestScope(t *testing.T) {
require.NoError(t, req.ParseForm())
values := req.PostForm
- assert.Equal(t, values.Get("grant_type"), "client_credentials")
- assert.Equal(t, values.Get("client_secret"), "secret")
- assert.Equal(t, values.Get("scope"), "my_scope")
-
- w.WriteHeader(http.StatusOK)
+ assert.Equal(t, "client_credentials", values.Get("grant_type"))
+ assert.Equal(t, "secret", values.Get("client_secret"))
+ assert.Equal(t, "my_scope", values.Get("scope"))
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "some_jwt_token",
"token_type": "Bearer",
@@ -179,13 +179,12 @@ func TestAuthHeader(t *testing.T) {
require.NoError(t, req.ParseForm())
values := req.PostForm
- assert.Equal(t, values.Get("grant_type"), "client_credentials")
- assert.Equal(t, values.Get("client_id"), "client")
- assert.Equal(t, values.Get("client_secret"), "secret")
- assert.Equal(t, values.Get("scope"), "catalog")
-
- w.WriteHeader(http.StatusOK)
+ assert.Equal(t, "client_credentials", values.Get("grant_type"))
+ assert.Equal(t, "client", values.Get("client_id"))
+ assert.Equal(t, "secret", values.Get("client_secret"))
+ assert.Equal(t, "catalog", values.Get("scope"))
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "some_jwt_token",
"token_type": "Bearer",
@@ -194,19 +193,30 @@ func TestAuthHeader(t *testing.T) {
})
})
+ var capturedAuthHeader string
+ mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r
*http.Request) {
+ capturedAuthHeader = r.Header.Get("Authorization")
+ json.NewEncoder(w).Encode(map[string]any{"namespaces":
[][]string{}})
+ })
+
cat, err := NewCatalog(context.Background(), "rest", srv.URL,
WithCredential("client:secret"))
require.NoError(t, err)
assert.NotNil(t, cat)
+ // Verify default headers (excluding Authorization, which is now set
per-request).
require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport)
assert.Equal(t, http.Header{
- "Authorization": {"Bearer some_jwt_token"},
"Content-Type": {"application/json"},
"User-Agent": {"GoIceberg/(unknown version)"},
"X-Client-Version": {icebergRestSpecVersion},
"X-Iceberg-Access-Delegation": {"vended-credentials"},
}, cat.cl.Transport.(*sessionTransport).defaultHeaders)
+
+ // Verify Authorization is set on actual requests.
+ _, err = cat.ListNamespaces(context.Background(), nil)
+ require.NoError(t, err)
+ assert.Equal(t, "Bearer some_jwt_token", capturedAuthHeader)
}
func TestAuthUriHeader(t *testing.T) {
@@ -227,13 +237,12 @@ func TestAuthUriHeader(t *testing.T) {
require.NoError(t, req.ParseForm())
values := req.PostForm
- assert.Equal(t, values.Get("grant_type"), "client_credentials")
- assert.Equal(t, values.Get("client_id"), "client")
- assert.Equal(t, values.Get("client_secret"), "secret")
- assert.Equal(t, values.Get("scope"), "catalog")
-
- w.WriteHeader(http.StatusOK)
+ assert.Equal(t, "client_credentials", values.Get("grant_type"))
+ assert.Equal(t, "client", values.Get("client_id"))
+ assert.Equal(t, "secret", values.Get("client_secret"))
+ assert.Equal(t, "catalog", values.Get("scope"))
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "some_jwt_token",
"token_type": "Bearer",
@@ -242,6 +251,12 @@ func TestAuthUriHeader(t *testing.T) {
})
})
+ var capturedAuthHeader string
+ mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r
*http.Request) {
+ capturedAuthHeader = r.Header.Get("Authorization")
+ json.NewEncoder(w).Encode(map[string]any{"namespaces":
[][]string{}})
+ })
+
authUri, err := url.Parse(srv.URL)
require.NoError(t, err)
cat, err := NewCatalog(context.Background(), "rest", srv.URL,
@@ -251,12 +266,15 @@ func TestAuthUriHeader(t *testing.T) {
require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport)
assert.Equal(t, http.Header{
- "Authorization": {"Bearer some_jwt_token"},
"Content-Type": {"application/json"},
"User-Agent": {"GoIceberg/(unknown version)"},
"X-Client-Version": {icebergRestSpecVersion},
"X-Iceberg-Access-Delegation": {"vended-credentials"},
}, cat.cl.Transport.(*sessionTransport).defaultHeaders)
+
+ _, err = cat.ListNamespaces(context.Background(), nil)
+ require.NoError(t, err)
+ assert.Equal(t, "Bearer some_jwt_token", capturedAuthHeader)
}
func TestSigv4EmptyStringHash(t *testing.T) {
@@ -466,6 +484,81 @@ func TestSigv4ConcurrentSigners(t *testing.T) {
t.Logf("issued %d requests", count.Load())
}
+// TestCredentialRefreshOnExpiry verifies that a short-lived OAuth token is
+// replaced by a freshly fetched one instead of being reused for the lifetime
+// of the catalog.
+func TestCredentialRefreshOnExpiry(t *testing.T) {
+	t.Parallel()
+
+	var tokenVersion atomic.Int64
+	var oauthCallCount atomic.Int64
+
+	mux := http.NewServeMux()
+	srv := httptest.NewServer(mux)
+	defer srv.Close()
+
+	mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) {
+		json.NewEncoder(w).Encode(map[string]any{
+			"defaults": map[string]any{}, "overrides": map[string]any{},
+		})
+	})
+
+	// Every token request issues a new token version and makes it the only
+	// version the namespaces endpoint accepts.
+	mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, r *http.Request) {
+		n := oauthCallCount.Add(1)
+		tokenVersion.Store(n)
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(map[string]any{
+			"access_token": fmt.Sprintf("token_v%d", n),
+			"token_type":   "Bearer",
+			"expires_in":   1, // expires in 1 second
+		})
+	})
+
+	mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, r *http.Request) {
+		auth := r.Header.Get("Authorization")
+		currentVersion := tokenVersion.Load()
+		expectedToken := fmt.Sprintf("Bearer token_v%d", currentVersion)
+
+		if auth != expectedToken {
+			// Simulate server rejecting an expired/stale token.
+			w.WriteHeader(http.StatusUnauthorized)
+			json.NewEncoder(w).Encode(map[string]any{
+				"error": map[string]any{
+					"message": "Token expired",
+					"type":    "NotAuthorizedException",
+					"code":    401,
+				},
+			})
+
+			return
+		}
+		json.NewEncoder(w).Encode(map[string]any{
+			"namespaces": [][]string{{"ns1"}},
+		})
+	})
+
+	cat, err := NewCatalog(context.Background(), "rest", srv.URL,
+		WithCredential("client:secret"))
+	require.NoError(t, err)
+
+	// First call should succeed with a freshly fetched token.
+	namespaces, err := cat.ListNamespaces(context.Background(), nil)
+	require.NoError(t, err)
+	assert.Len(t, namespaces, 1)
+
+	// Wait for the token to expire, then bump the server's expected version
+	// so the token used by the first call is rejected from now on.
+	time.Sleep(2 * time.Second)
+	tokenVersion.Add(1)
+	callsBefore := oauthCallCount.Load()
+
+	// The stale token would be rejected with 401, so this call only succeeds
+	// if the catalog obtains a fresh token for it.
+	namespaces, err = cat.ListNamespaces(context.Background(), nil)
+	require.NoError(t, err, "catalog should refresh expired credentials automatically")
+	assert.Len(t, namespaces, 1)
+
+	// Compare against a snapshot rather than a fixed total: the cumulative
+	// count may already be satisfied by token fetches made before this call.
+	assert.Greater(t, oauthCallCount.Load(), callsBefore,
+		"OAuth endpoint should be called again to refresh the expired token")
+}
+
// trackingReadCloser wraps an io.ReadCloser to track if Close() was called
type trackingReadCloser struct {
io.ReadCloser
diff --git a/catalog/rest/rest_test.go b/catalog/rest/rest_test.go
index 046c415a..ae8e16aa 100644
--- a/catalog/rest/rest_test.go
+++ b/catalog/rest/rest_test.go
@@ -101,8 +101,7 @@ func (r *RestCatalogSuite) TestToken200() {
r.Equal(values.Get("client_secret"), "secret")
r.Equal(values.Get("scope"), scope)
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -135,8 +134,7 @@ func (r *RestCatalogSuite) TestLoadRegisteredCatalog() {
r.Equal(values.Get("client_secret"), "secret")
r.Equal(values.Get("scope"), "catalog")
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -162,6 +160,7 @@ func (r *RestCatalogSuite) TestToken400() {
r.Equal(req.Header.Get("Content-Type"),
"application/x-www-form-urlencoded")
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
@@ -191,8 +190,7 @@ func (r *RestCatalogSuite) TestToken200AuthUrl() {
r.Equal(values.Get("client_secret"), "secret")
r.Equal(values.Get("scope"), "catalog")
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -219,6 +217,7 @@ func (r *RestCatalogSuite) TestToken401() {
r.Equal(req.Header.Get("Content-Type"),
"application/x-www-form-urlencoded")
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]any{
@@ -242,8 +241,7 @@ func (r *RestCatalogSuite) TestTokenContentTypeDuplicated()
{
values := req.Header.Values("Content-Type")
r.Equal([]string{"application/x-www-form-urlencoded"}, values)
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -307,8 +305,7 @@ func (r *RestCatalogSuite) TestWithHeadersOnOAuthRoute() {
r.Equal(v, req.Header.Get(k))
}
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -339,8 +336,7 @@ func (r *RestCatalogSuite) TestWithHeadersOnAuthURLRoute() {
r.Equal(v, req.Header.Get(k))
}
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -418,8 +414,7 @@ func (r *RestCatalogSuite) TestListTablesPrefixed200() {
r.Equal(values.Get("client_secret"), "secret")
r.Equal(values.Get("scope"), "catalog")
- w.WriteHeader(http.StatusOK)
-
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken,
"token_type": "Bearer",
@@ -2752,6 +2747,7 @@ func (r *RestCatalogSuite) TestCreateTableStaged() {
var lastCommitBody map[string]any
r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req
*http.Request) {
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken, "token_type": "Bearer",
"expires_in": 3600,
})
@@ -2876,6 +2872,7 @@ func (r *RestCatalogSuite) TestCreateTableNotStaged() {
var commitCalled bool
r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req
*http.Request) {
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken, "token_type": "Bearer",
"expires_in": 3600,
})
@@ -2918,6 +2915,7 @@ func (r *RestCatalogSuite)
TestCommitTableErrCommitStateUnknown() {
var statusCode int
r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req
*http.Request) {
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken, "token_type": "Bearer",
"expires_in": 3600,
})
@@ -2962,6 +2960,7 @@ func (r *RestCatalogSuite)
TestUpdateTableErrCommitStateUnknown() {
var statusCode int
r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req
*http.Request) {
+ w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": TestToken, "token_type": "Bearer",
"expires_in": 3600,
})
diff --git a/go.mod b/go.mod
index 490f9258..cfc2fa4b 100644
--- a/go.mod
+++ b/go.mod
@@ -52,6 +52,7 @@ require (
github.com/uptrace/bun/driver/sqliteshim v1.2.18
github.com/uptrace/bun/extra/bundebug v1.2.18
gocloud.dev v0.45.0
+ golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
google.golang.org/api v0.271.0
gopkg.in/yaml.v3 v3.0.1
@@ -264,7 +265,6 @@ require (
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.51.0 // indirect
- golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect
golang.org/x/term v0.40.0 // indirect