This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new cba396f4ab [python] support rest token refresh (#7007)
cba396f4ab is described below
commit cba396f4ab4640a1249719eda3c6c7e9741511b2
Author: XiaoHongbo <[email protected]>
AuthorDate: Sat Jan 17 21:49:51 2026 +0800
[python] support rest token refresh (#7007)
---
paimon-python/pypaimon/api/rest_util.py | 6 +-
.../pypaimon/catalog/rest/rest_token_file_io.py | 128 +++++++++++++++------
paimon-python/pypaimon/read/reader/lance_utils.py | 55 +++++----
paimon-python/pypaimon/tests/rest/rest_server.py | 48 +++++++-
4 files changed, 179 insertions(+), 58 deletions(-)
diff --git a/paimon-python/pypaimon/api/rest_util.py b/paimon-python/pypaimon/api/rest_util.py
index 97a709ecc3..fd4d1da040 100644
--- a/paimon-python/pypaimon/api/rest_util.py
+++ b/paimon-python/pypaimon/api/rest_util.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Dict
+from typing import Dict, Optional
from urllib.parse import unquote
from pypaimon.common.options import Options
@@ -46,8 +46,8 @@ class RESTUtil:
@staticmethod
def merge(
- base_properties: Dict[str, str],
- override_properties: Dict[str, str]) -> Dict[str, str]:
+ base_properties: Optional[Dict[str, str]],
+ override_properties: Optional[Dict[str, str]]) -> Dict[str, str]:
if override_properties is None:
override_properties = {}
if base_properties is None:
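
For reference, a minimal standalone sketch of the None-tolerant merge semantics this hunk introduces (a hypothetical free function, not the project's actual RESTUtil; override values win on key collisions):

    from typing import Dict, Optional

    def merge(base_properties: Optional[Dict[str, str]],
              override_properties: Optional[Dict[str, str]]) -> Dict[str, str]:
        # None is now accepted and treated as an empty mapping, so callers
        # no longer have to guard unset catalog options themselves.
        merged = dict(base_properties or {})
        merged.update(override_properties or {})
        return merged

    assert merge(None, {"k": "v"}) == {"k": "v"}
    assert merge({"k": "v"}, None) == {"k": "v"}
    assert merge({"k": "a"}, {"k": "b"}) == {"k": "b"}  # override wins
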
diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
index f686dc66ea..7769ba639b 100644
--- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
+++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -18,7 +18,7 @@ limitations under the License.
import logging
import threading
import time
-from typing import Optional
+from typing import Optional, Union
from cachetools import TTLCache
@@ -41,40 +41,55 @@ class RESTTokenFileIO(FileIO):
_FILE_IO_CACHE_MAXSIZE = 1000
_FILE_IO_CACHE_TTL = 36000 # 10 hours in seconds
+ _FILE_IO_CACHE: TTLCache = None
+ _FILE_IO_CACHE_LOCK = threading.Lock()
+
+ _TOKEN_CACHE: dict = {}
+ _TOKEN_LOCKS: dict = {}
+ _TOKEN_LOCKS_LOCK = threading.Lock()
+
+ @classmethod
+ def _get_file_io_cache(cls) -> TTLCache:
+ if cls._FILE_IO_CACHE is None:
+ with cls._FILE_IO_CACHE_LOCK:
+ if cls._FILE_IO_CACHE is None:
+ cls._FILE_IO_CACHE = TTLCache(
+ maxsize=cls._FILE_IO_CACHE_MAXSIZE,
+ ttl=cls._FILE_IO_CACHE_TTL
+ )
+ return cls._FILE_IO_CACHE
+
def __init__(self, identifier: Identifier, path: str,
- catalog_options: Optional[Options] = None):
+ catalog_options: Optional[Union[dict, Options]] = None):
self.identifier = identifier
self.path = path
- self.catalog_options = catalog_options
- self.properties = catalog_options or Options({}) # For compatibility with refresh_token()
+ if catalog_options is None:
+ self.catalog_options = None
+ elif isinstance(catalog_options, dict):
+ self.catalog_options = Options(catalog_options)
+ else:
+ # Assume it's already an Options object
+ self.catalog_options = catalog_options
+ self.properties = self.catalog_options or Options({}) # For compatibility with refresh_token()
self.token: Optional[RESTToken] = None
self.api_instance: Optional[RESTApi] = None
self.lock = threading.Lock()
self.log = logging.getLogger(__name__)
self._uri_reader_factory_cache: Optional[UriReaderFactory] = None
- self._file_io_cache: TTLCache = TTLCache(
- maxsize=self._FILE_IO_CACHE_MAXSIZE,
- ttl=self._FILE_IO_CACHE_TTL
- )
def __getstate__(self):
state = self.__dict__.copy()
# Remove non-serializable objects
state.pop('lock', None)
state.pop('api_instance', None)
- state.pop('_file_io_cache', None)
state.pop('_uri_reader_factory_cache', None)
# token can be serialized, but we'll refresh it on deserialization
return state
def __setstate__(self, state):
self.__dict__.update(state)
- # Recreate lock and cache after deserialization
+ # Recreate lock after deserialization
self.lock = threading.Lock()
- self._file_io_cache = TTLCache(
- maxsize=self._FILE_IO_CACHE_MAXSIZE,
- ttl=self._FILE_IO_CACHE_TTL
- )
self._uri_reader_factory_cache = None
# api_instance will be recreated when needed
self.api_instance = None
@@ -86,25 +101,36 @@ class RESTTokenFileIO(FileIO):
return FileIO.get(self.path, self.catalog_options or Options({}))
cache_key = self.token
+ cache = self._get_file_io_cache()
- file_io = self._file_io_cache.get(cache_key)
+ file_io = cache.get(cache_key)
if file_io is not None:
return file_io
- with self.lock:
- file_io = self._file_io_cache.get(cache_key)
+ with self._FILE_IO_CACHE_LOCK:
+ self.try_to_refresh_token()
+
+ if self.token is None:
+ return FileIO.get(self.path, self.catalog_options or Options({}))
+
+ cache_key = self.token
+ cache = self._get_file_io_cache()
+ file_io = cache.get(cache_key)
if file_io is not None:
return file_io
- merged_token = self._merge_token_with_catalog_options(self.token.token)
merged_properties = RESTUtil.merge(
self.catalog_options.to_map() if self.catalog_options else {},
- merged_token
+ self.token.token
)
+ if self.catalog_options:
+ dlf_oss_endpoint = self.catalog_options.get(CatalogOptions.DLF_OSS_ENDPOINT)
+ if dlf_oss_endpoint and dlf_oss_endpoint.strip():
+ merged_properties[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint
merged_options = Options(merged_properties)
file_io = PyArrowFileIO(self.path, merged_options)
- self._file_io_cache[cache_key] = file_io
+ cache[cache_key] = file_io
return file_io
def _merge_token_with_catalog_options(self, token: dict) -> dict:
@@ -180,16 +206,55 @@ class RESTTokenFileIO(FileIO):
return self.file_io().filesystem
def try_to_refresh_token(self):
- if self.should_refresh():
- with self.lock:
- if self.should_refresh():
- self.refresh_token()
+ identifier_str = str(self.identifier)
+
+ if self.token is not None and not self._is_token_expired(self.token):
+ return
+
+ cached_token = self._get_cached_token(identifier_str)
+ if cached_token and not self._is_token_expired(cached_token):
+ self.token = cached_token
+ return
+
+ global_lock = self._get_global_token_lock(identifier_str)
+
+ with global_lock:
+ cached_token = self._get_cached_token(identifier_str)
+ if cached_token and not self._is_token_expired(cached_token):
+ self.token = cached_token
+ return
+
+ token_to_check = cached_token if cached_token else self.token
+ if token_to_check is None or self._is_token_expired(token_to_check):
+ self.refresh_token()
+ self._set_cached_token(identifier_str, self.token)
+
+ def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
+ with self._TOKEN_LOCKS_LOCK:
+ return self._TOKEN_CACHE.get(identifier_str)
+
+ def _set_cached_token(self, identifier_str: str, token: RESTToken):
+ with self._TOKEN_LOCKS_LOCK:
+ self._TOKEN_CACHE[identifier_str] = token
+
+ def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
+ if token is None:
+ return True
+ current_time = int(time.time() * 1000)
+ return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+
+ def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
+ with self._TOKEN_LOCKS_LOCK:
+ if identifier_str not in self._TOKEN_LOCKS:
+ self._TOKEN_LOCKS[identifier_str] = threading.Lock()
+ return self._TOKEN_LOCKS[identifier_str]
def should_refresh(self):
if self.token is None:
return True
current_time = int(time.time() * 1000)
- return (self.token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+ time_until_expiry = self.token.expire_at_millis - current_time
+ return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
def refresh_token(self):
self.log.info(f"begin refresh data token for identifier
[{self.identifier}]")
@@ -200,17 +265,14 @@ class RESTTokenFileIO(FileIO):
self.log.info(
f"end refresh data token for identifier [{self.identifier}]
expiresAtMillis [{response.expires_at_millis}]"
)
- self.token = RESTToken(response.token, response.expires_at_millis)
+
+ merged_token_dict = self._merge_token_with_catalog_options(response.token)
+ new_token = RESTToken(merged_token_dict, response.expires_at_millis)
+ self.token = new_token
def valid_token(self):
self.try_to_refresh_token()
return self.token
def close(self):
- with self.lock:
- for file_io in self._file_io_cache.values():
- try:
- file_io.close()
- except Exception as e:
- self.log.warning(f"Error closing cached FileIO: {e}")
- self._file_io_cache.clear()
+ pass
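
The heart of this rewrite is moving both caches to class level so they are shared process-wide: the FileIO TTLCache is created lazily under double-checked locking, and token refreshes are serialized per identifier, so concurrent readers of one table trigger a single REST call while different tables do not contend. A minimal sketch of that locking scheme under stated assumptions (refresh and is_expired are illustrative callables, not the actual API):

    import threading
    from cachetools import TTLCache

    class SharedTokenCache:
        """Illustrative stand-in for the class-level state added above."""
        _FILE_IO_CACHE: TTLCache = None
        _FILE_IO_CACHE_LOCK = threading.Lock()
        _TOKEN_CACHE: dict = {}
        _TOKEN_LOCKS: dict = {}
        _TOKEN_LOCKS_LOCK = threading.Lock()

        @classmethod
        def _file_io_cache(cls) -> TTLCache:
            # Double-checked locking: lock-free fast path, the lock is
            # taken only for the one-time initialization.
            if cls._FILE_IO_CACHE is None:
                with cls._FILE_IO_CACHE_LOCK:
                    if cls._FILE_IO_CACHE is None:
                        cls._FILE_IO_CACHE = TTLCache(maxsize=1000, ttl=36000)
            return cls._FILE_IO_CACHE

        @classmethod
        def _lock_for(cls, identifier: str) -> threading.Lock:
            # One lock per identifier: same-table refreshes serialize,
            # different tables refresh in parallel.
            with cls._TOKEN_LOCKS_LOCK:
                return cls._TOKEN_LOCKS.setdefault(identifier, threading.Lock())

        @classmethod
        def get_token(cls, identifier: str, refresh, is_expired):
            token = cls._TOKEN_CACHE.get(identifier)
            if token is not None and not is_expired(token):
                return token
            with cls._lock_for(identifier):
                # Re-check under the lock: another thread may have
                # refreshed while we were waiting.
                token = cls._TOKEN_CACHE.get(identifier)
                if token is None or is_expired(token):
                    token = refresh()
                    cls._TOKEN_CACHE[identifier] = token
                return token

Note that _is_token_expired in the diff subtracts a safety window (RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS), so tokens are refreshed shortly before their real expiry rather than after it.
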
diff --git a/paimon-python/pypaimon/read/reader/lance_utils.py b/paimon-python/pypaimon/read/reader/lance_utils.py
index c219dc6704..2e3a331e4b 100644
--- a/paimon-python/pypaimon/read/reader/lance_utils.py
+++ b/paimon-python/pypaimon/read/reader/lance_utils.py
@@ -26,12 +26,24 @@ from pypaimon.common.options.config import OssOptions
def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[Dict[str, str]]]:
"""Convert path and extract storage options for Lance format."""
+ # For RESTTokenFileIO, get underlying FileIO which already has latest token merged
+ # This follows Java implementation: ((RESTTokenFileIO) fileIO).fileIO()
+ # The file_io() method will refresh token and return a FileIO with merged token
if hasattr(file_io, 'file_io'):
+ # Call file_io() to get underlying FileIO with latest token
+ # This ensures token is refreshed and merged with catalog options
file_io = file_io.file_io()
+ # Now get properties from the underlying FileIO (which has latest token)
+ if hasattr(file_io, 'get_merged_properties'):
+ properties = file_io.get_merged_properties()
+ else:
+ properties = file_io.properties if hasattr(file_io, 'properties') and file_io.properties else None
+
scheme, _, _ = file_io.parse_location(file_path)
- storage_options = None
file_path_for_lance = file_io.to_filesystem_path(file_path)
+
+ storage_options = None
if scheme in {'file', None} or not scheme:
if not os.path.isabs(file_path_for_lance):
@@ -40,37 +52,40 @@ def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[D
file_path_for_lance = file_path
if scheme == 'oss':
- storage_options = {}
- if hasattr(file_io, 'properties'):
- for key, value in file_io.properties.data.items():
+ parsed = urlparse(file_path)
+ bucket = parsed.netloc
+ path = parsed.path.lstrip('/')
+
+ if properties:
+ storage_options = {}
+ for key, value in properties.to_map().items():
if str(key).startswith('fs.'):
storage_options[key] = value
- parsed = urlparse(file_path)
- bucket = parsed.netloc
- path = parsed.path.lstrip('/')
-
- endpoint = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+ endpoint = properties.get(OssOptions.OSS_ENDPOINT)
if endpoint:
endpoint_clean = endpoint.replace('http://', '').replace('https://', '')
storage_options['endpoint'] = f"https://{bucket}.{endpoint_clean}"
- if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
- storage_options['access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
- storage_options['oss_access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
- if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
- storage_options['secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
- storage_options['oss_secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
- if file_io.properties.contains(OssOptions.OSS_SECURITY_TOKEN):
- storage_options['session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
- storage_options['oss_session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
- if file_io.properties.contains(OssOptions.OSS_ENDPOINT):
- storage_options['oss_endpoint'] = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+ if properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
+ storage_options['access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+ storage_options['oss_access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+ if properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
+ storage_options['secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+ storage_options['oss_secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+ if properties.contains(OssOptions.OSS_SECURITY_TOKEN):
+ storage_options['session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+ storage_options['oss_session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+ if properties.contains(OssOptions.OSS_ENDPOINT):
+ storage_options['oss_endpoint'] = properties.get(OssOptions.OSS_ENDPOINT)
+
storage_options['virtual_hosted_style_request'] = 'true'
if bucket and path:
file_path_for_lance = f"oss://{bucket}/{path}"
elif bucket:
file_path_for_lance = f"oss://{bucket}"
+ else:
+ storage_options = None
return file_path_for_lance, storage_options
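
Condensed, the OSS branch now reads the possibly token-refreshed properties once and maps them onto Lance storage options in both generic and oss_-prefixed spellings. A self-contained sketch of that mapping, with hypothetical flat key strings standing in for the OssOptions constants:

    from typing import Dict, Optional
    from urllib.parse import urlparse

    # Hypothetical flat keys standing in for the OssOptions constants.
    OSS_ENDPOINT = "fs.oss.endpoint"
    OSS_ACCESS_KEY_ID = "fs.oss.accessKeyId"
    OSS_ACCESS_KEY_SECRET = "fs.oss.accessKeySecret"
    OSS_SECURITY_TOKEN = "fs.oss.securityToken"

    def oss_storage_options(props: Dict[str, str],
                            file_path: str) -> Optional[Dict[str, str]]:
        if not props:
            return None
        bucket = urlparse(file_path).netloc
        # Start from every fs.* property, then layer on the names Lance's
        # OSS object store expects (both generic and oss_-prefixed).
        opts = {k: v for k, v in props.items() if str(k).startswith("fs.")}
        endpoint = props.get(OSS_ENDPOINT)
        if endpoint:
            host = endpoint.replace("http://", "").replace("https://", "")
            # Virtual-hosted-style request: the bucket joins the host name.
            opts["endpoint"] = f"https://{bucket}.{host}"
            opts["oss_endpoint"] = endpoint
        if OSS_ACCESS_KEY_ID in props:
            opts["access_key_id"] = opts["oss_access_key_id"] = props[OSS_ACCESS_KEY_ID]
        if OSS_ACCESS_KEY_SECRET in props:
            opts["secret_access_key"] = opts["oss_secret_access_key"] = props[OSS_ACCESS_KEY_SECRET]
        if OSS_SECURITY_TOKEN in props:
            opts["session_token"] = opts["oss_session_token"] = props[OSS_SECURITY_TOKEN]
        opts["virtual_hosted_style_request"] = "true"
        return opts
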
diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py
index d556f7f55c..cb7b321083 100755
--- a/paimon-python/pypaimon/tests/rest/rest_server.py
+++ b/paimon-python/pypaimon/tests/rest/rest_server.py
@@ -24,9 +24,12 @@ import uuid
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from urllib.parse import urlparse
+if TYPE_CHECKING:
+ from pypaimon.catalog.rest.rest_token import RESTToken
+
from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest,
CreateTableRequest, RenameTableRequest)
from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse,
@@ -213,6 +216,7 @@ class RESTCatalogServer:
self.table_partitions_store: Dict[str, List] = {}
self.no_permission_databases: List[str] = []
self.no_permission_tables: List[str] = []
+ self.table_token_store: Dict[str, "RESTToken"] = {}
# Initialize mock catalog (simplified)
self.data_path = data_path
@@ -469,10 +473,12 @@ class RESTCatalogServer:
# Basic table operations (GET, DELETE, etc.)
return self._table_handle(method, data, lookup_identifier)
elif len(path_parts) == 4:
- # Extended operations (e.g., commit)
+ # Extended operations (e.g., commit, token)
operation = path_parts[3]
if operation == "commit":
return self._table_commit_handle(method, data, lookup_identifier, branch_part)
+ elif operation == "token":
+ return self._table_token_handle(method, lookup_identifier)
else:
return self._mock_response(ErrorResponse(None, None, "Not
Found", 404), 404)
return self._mock_response(ErrorResponse(None, None, "Not Found",
404), 404)
@@ -574,6 +580,44 @@ class RESTCatalogServer:
return self._mock_response(ErrorResponse(None, None, "Method Not
Allowed", 405), 405)
+ def _table_token_handle(self, method: str, identifier: Identifier) -> Tuple[str, int]:
+ if method != "GET":
+ return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
+
+ if identifier.get_full_name() not in self.table_metadata_store:
+ raise TableNotExistException(identifier)
+
+ from pypaimon.api.api_response import GetTableTokenResponse
+
+ token_key = identifier.get_full_name()
+ if token_key in self.table_token_store:
+ rest_token = self.table_token_store[token_key]
+ response = GetTableTokenResponse(
+ token=rest_token.token,
+ expires_at_millis=rest_token.expire_at_millis
+ )
+ else:
+ default_token = {
+ "akId": "akId" + str(int(time.time() * 1000)),
+ "akSecret": "akSecret" + str(int(time.time() * 1000))
+ }
+ response = GetTableTokenResponse(
+ token=default_token,
+ expires_at_millis=int(time.time() * 1000) + 3600_000 # 1 hour from now
+ )
+
+ return self._mock_response(response, 200)
+
+ def set_table_token(self, identifier: Identifier, token: "RESTToken") -> None:
+ self.table_token_store[identifier.get_full_name()] = token
+
+ def get_table_token(self, identifier: Identifier) -> Optional["RESTToken"]:
+ return self.table_token_store.get(identifier.get_full_name())
+
+ def reset_table_token(self, identifier: Identifier) -> None:
+ if identifier.get_full_name() in self.table_token_store:
+ del self.table_token_store[identifier.get_full_name()]
+
def _table_commit_handle(self, method: str, data: str, identifier: Identifier,
branch: str = None) -> Tuple[str, int]:
"""Handle table commit operations"""