This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new cba396f4ab [python] support rest token refresh  (#7007)
cba396f4ab is described below

commit cba396f4ab4640a1249719eda3c6c7e9741511b2
Author: XiaoHongbo <[email protected]>
AuthorDate: Sat Jan 17 21:49:51 2026 +0800

    [python] support rest token refresh  (#7007)
---
 paimon-python/pypaimon/api/rest_util.py            |   6 +-
 .../pypaimon/catalog/rest/rest_token_file_io.py    | 128 +++++++++++++++------
 paimon-python/pypaimon/read/reader/lance_utils.py  |  55 +++++----
 paimon-python/pypaimon/tests/rest/rest_server.py   |  48 +++++++-
 4 files changed, 179 insertions(+), 58 deletions(-)

diff --git a/paimon-python/pypaimon/api/rest_util.py b/paimon-python/pypaimon/api/rest_util.py
index 97a709ecc3..fd4d1da040 100644
--- a/paimon-python/pypaimon/api/rest_util.py
+++ b/paimon-python/pypaimon/api/rest_util.py
@@ -15,7 +15,7 @@
 #  specific language governing permissions and limitations
 #  under the License.
 
-from typing import Dict
+from typing import Dict, Optional
 from urllib.parse import unquote
 
 from pypaimon.common.options import Options
@@ -46,8 +46,8 @@ class RESTUtil:
 
     @staticmethod
     def merge(
-            base_properties: Dict[str, str],
-            override_properties: Dict[str, str]) -> Dict[str, str]:
+            base_properties: Optional[Dict[str, str]],
+            override_properties: Optional[Dict[str, str]]) -> Dict[str, str]:
         if override_properties is None:
             override_properties = {}
         if base_properties is None:
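
The merge change above makes both inputs optional, so callers can pass None
for either map. A minimal sketch of those semantics with plain dicts (not the
pypaimon Options type; override entries are assumed to win on key clashes, as
the parameter names suggest):

    from typing import Dict, Optional

    def merge(base: Optional[Dict[str, str]],
              override: Optional[Dict[str, str]]) -> Dict[str, str]:
        # None inputs are treated as empty maps; override wins on key clashes.
        merged = dict(base or {})
        merged.update(override or {})
        return merged

    assert merge(None, {"k": "v"}) == {"k": "v"}
    assert merge({"k": "a"}, {"k": "b"}) == {"k": "b"}
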
diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
index f686dc66ea..7769ba639b 100644
--- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
+++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -18,7 +18,7 @@ limitations under the License.
 import logging
 import threading
 import time
-from typing import Optional
+from typing import Optional, Union
 
 from cachetools import TTLCache
 
@@ -41,40 +41,55 @@ class RESTTokenFileIO(FileIO):
     _FILE_IO_CACHE_MAXSIZE = 1000
     _FILE_IO_CACHE_TTL = 36000  # 10 hours in seconds
     
+    _FILE_IO_CACHE: TTLCache = None
+    _FILE_IO_CACHE_LOCK = threading.Lock()
+    
+    _TOKEN_CACHE: dict = {}
+    _TOKEN_LOCKS: dict = {}
+    _TOKEN_LOCKS_LOCK = threading.Lock()
+    
+    @classmethod
+    def _get_file_io_cache(cls) -> TTLCache:
+        if cls._FILE_IO_CACHE is None:
+            with cls._FILE_IO_CACHE_LOCK:
+                if cls._FILE_IO_CACHE is None:
+                    cls._FILE_IO_CACHE = TTLCache(
+                        maxsize=cls._FILE_IO_CACHE_MAXSIZE,
+                        ttl=cls._FILE_IO_CACHE_TTL
+                    )
+        return cls._FILE_IO_CACHE
+    
     def __init__(self, identifier: Identifier, path: str,
-                 catalog_options: Optional[Options] = None):
+                 catalog_options: Optional[Union[dict, Options]] = None):
         self.identifier = identifier
         self.path = path
-        self.catalog_options = catalog_options
-        self.properties = catalog_options or Options({})  # For compatibility with refresh_token()
+        if catalog_options is None:
+            self.catalog_options = None
+        elif isinstance(catalog_options, dict):
+            self.catalog_options = Options(catalog_options)
+        else:
+            # Assume it's already an Options object
+            self.catalog_options = catalog_options
+        self.properties = self.catalog_options or Options({})  # For compatibility with refresh_token()
         self.token: Optional[RESTToken] = None
         self.api_instance: Optional[RESTApi] = None
         self.lock = threading.Lock()
         self.log = logging.getLogger(__name__)
         self._uri_reader_factory_cache: Optional[UriReaderFactory] = None
-        self._file_io_cache: TTLCache = TTLCache(
-            maxsize=self._FILE_IO_CACHE_MAXSIZE,
-            ttl=self._FILE_IO_CACHE_TTL
-        )
 
     def __getstate__(self):
         state = self.__dict__.copy()
         # Remove non-serializable objects
         state.pop('lock', None)
         state.pop('api_instance', None)
-        state.pop('_file_io_cache', None)
         state.pop('_uri_reader_factory_cache', None)
         # token can be serialized, but we'll refresh it on deserialization
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        # Recreate lock and cache after deserialization
+        # Recreate lock after deserialization
         self.lock = threading.Lock()
-        self._file_io_cache = TTLCache(
-            maxsize=self._FILE_IO_CACHE_MAXSIZE,
-            ttl=self._FILE_IO_CACHE_TTL
-        )
         self._uri_reader_factory_cache = None
         # api_instance will be recreated when needed
         self.api_instance = None
@@ -86,25 +101,36 @@ class RESTTokenFileIO(FileIO):
             return FileIO.get(self.path, self.catalog_options or Options({}))
         
         cache_key = self.token
+        cache = self._get_file_io_cache()
         
-        file_io = self._file_io_cache.get(cache_key)
+        file_io = cache.get(cache_key)
         if file_io is not None:
             return file_io
         
-        with self.lock:
-            file_io = self._file_io_cache.get(cache_key)
+        with self._FILE_IO_CACHE_LOCK:
+            self.try_to_refresh_token()
+            
+            if self.token is None:
+                return FileIO.get(self.path, self.catalog_options or Options({}))
+            
+            cache_key = self.token
+            cache = self._get_file_io_cache()
+            file_io = cache.get(cache_key)
             if file_io is not None:
                 return file_io
             
-            merged_token = self._merge_token_with_catalog_options(self.token.token)
             merged_properties = RESTUtil.merge(
                 self.catalog_options.to_map() if self.catalog_options else {},
-                merged_token
+                self.token.token
             )
+            if self.catalog_options:
+                dlf_oss_endpoint = self.catalog_options.get(CatalogOptions.DLF_OSS_ENDPOINT)
+                if dlf_oss_endpoint and dlf_oss_endpoint.strip():
+                    merged_properties[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint
             merged_options = Options(merged_properties)
             
             file_io = PyArrowFileIO(self.path, merged_options)
-            self._file_io_cache[cache_key] = file_io
+            cache[cache_key] = file_io
             return file_io
 
     def _merge_token_with_catalog_options(self, token: dict) -> dict:
@@ -180,16 +206,55 @@ class RESTTokenFileIO(FileIO):
         return self.file_io().filesystem
 
     def try_to_refresh_token(self):
-        if self.should_refresh():
-            with self.lock:
-                if self.should_refresh():
-                    self.refresh_token()
+        identifier_str = str(self.identifier)
+        
+        if self.token is not None and not self._is_token_expired(self.token):
+            return
+        
+        cached_token = self._get_cached_token(identifier_str)
+        if cached_token and not self._is_token_expired(cached_token):
+            self.token = cached_token
+            return
+        
+        global_lock = self._get_global_token_lock(identifier_str)
+        
+        with global_lock:
+            cached_token = self._get_cached_token(identifier_str)
+            if cached_token and not self._is_token_expired(cached_token):
+                self.token = cached_token
+                return
+            
+            token_to_check = cached_token if cached_token else self.token
+            if token_to_check is None or self._is_token_expired(token_to_check):
+                self.refresh_token()
+                self._set_cached_token(identifier_str, self.token)
+
+    def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
+        with self._TOKEN_LOCKS_LOCK:
+            return self._TOKEN_CACHE.get(identifier_str)
+    
+    def _set_cached_token(self, identifier_str: str, token: RESTToken):
+        with self._TOKEN_LOCKS_LOCK:
+            self._TOKEN_CACHE[identifier_str] = token
+    
+    def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
+        if token is None:
+            return True
+        current_time = int(time.time() * 1000)
+        return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+    
+    def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
+        with self._TOKEN_LOCKS_LOCK:
+            if identifier_str not in self._TOKEN_LOCKS:
+                self._TOKEN_LOCKS[identifier_str] = threading.Lock()
+            return self._TOKEN_LOCKS[identifier_str]
 
     def should_refresh(self):
         if self.token is None:
             return True
         current_time = int(time.time() * 1000)
-        return (self.token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+        time_until_expiry = self.token.expire_at_millis - current_time
+        return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
 
     def refresh_token(self):
         self.log.info(f"begin refresh data token for identifier 
[{self.identifier}]")
@@ -200,17 +265,14 @@ class RESTTokenFileIO(FileIO):
         self.log.info(
             f"end refresh data token for identifier [{self.identifier}] 
expiresAtMillis [{response.expires_at_millis}]"
         )
-        self.token = RESTToken(response.token, response.expires_at_millis)
+        
+        merged_token_dict = self._merge_token_with_catalog_options(response.token)
+        new_token = RESTToken(merged_token_dict, response.expires_at_millis)
+        self.token = new_token
 
     def valid_token(self):
         self.try_to_refresh_token()
         return self.token
 
     def close(self):
-        with self.lock:
-            for file_io in self._file_io_cache.values():
-                try:
-                    file_io.close()
-                except Exception as e:
-                    self.log.warning(f"Error closing cached FileIO: {e}")
-            self._file_io_cache.clear()
+        pass
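
The refresh path above combines a process-wide token cache with one lock per
table identifier, so concurrent readers of the same table trigger at most one
REST call while different tables refresh independently. A condensed sketch of
that double-checked pattern (hypothetical Token tuple, safety window, and
fetch callback; not the pypaimon classes):

    import threading
    import time
    from typing import Callable, Dict, Optional, Tuple

    Token = Tuple[dict, int]  # (credentials, expire_at_millis)

    _CACHE: Dict[str, Token] = {}
    _LOCKS: Dict[str, threading.Lock] = {}
    _LOCKS_LOCK = threading.Lock()

    def _expired(token: Optional[Token], safe_millis: int = 60_000) -> bool:
        # A missing token, or one inside the safety window, counts as expired.
        return token is None or token[1] - int(time.time() * 1000) < safe_millis

    def get_token(identifier: str, fetch: Callable[[], Token]) -> Token:
        token = _CACHE.get(identifier)
        if not _expired(token):
            return token                      # fast path, no locking
        with _LOCKS_LOCK:                     # one lock object per identifier
            lock = _LOCKS.setdefault(identifier, threading.Lock())
        with lock:
            token = _CACHE.get(identifier)    # re-check under the lock
            if _expired(token):
                token = fetch()               # only one thread calls the server
                _CACHE[identifier] = token
            return token
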
diff --git a/paimon-python/pypaimon/read/reader/lance_utils.py b/paimon-python/pypaimon/read/reader/lance_utils.py
index c219dc6704..2e3a331e4b 100644
--- a/paimon-python/pypaimon/read/reader/lance_utils.py
+++ b/paimon-python/pypaimon/read/reader/lance_utils.py
@@ -26,12 +26,24 @@ from pypaimon.common.options.config import OssOptions
 
 def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[Dict[str, str]]]:
     """Convert path and extract storage options for Lance format."""
+    # For RESTTokenFileIO, get underlying FileIO which already has latest token merged
+    # This follows Java implementation: ((RESTTokenFileIO) fileIO).fileIO()
+    # The file_io() method will refresh token and return a FileIO with merged token
     if hasattr(file_io, 'file_io'):
+        # Call file_io() to get underlying FileIO with latest token
+        # This ensures token is refreshed and merged with catalog options
         file_io = file_io.file_io()
     
+    # Now get properties from the underlying FileIO (which has latest token)
+    if hasattr(file_io, 'get_merged_properties'):
+        properties = file_io.get_merged_properties()
+    else:
+        properties = file_io.properties if hasattr(file_io, 'properties') and file_io.properties else None
+
     scheme, _, _ = file_io.parse_location(file_path)
-    storage_options = None
     file_path_for_lance = file_io.to_filesystem_path(file_path)
+    
+    storage_options = None
 
     if scheme in {'file', None} or not scheme:
         if not os.path.isabs(file_path_for_lance):
@@ -40,37 +52,40 @@ def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[D
         file_path_for_lance = file_path
 
     if scheme == 'oss':
-        storage_options = {}
-        if hasattr(file_io, 'properties'):
-            for key, value in file_io.properties.data.items():
+        parsed = urlparse(file_path)
+        bucket = parsed.netloc
+        path = parsed.path.lstrip('/')
+
+        if properties:
+            storage_options = {}
+            for key, value in properties.to_map().items():
                 if str(key).startswith('fs.'):
                     storage_options[key] = value
 
-            parsed = urlparse(file_path)
-            bucket = parsed.netloc
-            path = parsed.path.lstrip('/')
-
-            endpoint = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+            endpoint = properties.get(OssOptions.OSS_ENDPOINT)
             if endpoint:
+                endpoint_clean = endpoint.replace('http://', '').replace('https://', '')
+                storage_options['endpoint'] = f"https://{bucket}.{endpoint_clean}"
 
-            if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
-                storage_options['access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
-                storage_options['oss_access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
-            if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
-                storage_options['secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
-                storage_options['oss_secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
-            if file_io.properties.contains(OssOptions.OSS_SECURITY_TOKEN):
-                storage_options['session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
-                storage_options['oss_session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
-            if file_io.properties.contains(OssOptions.OSS_ENDPOINT):
-                storage_options['oss_endpoint'] = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+            if properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
+                storage_options['access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+                storage_options['oss_access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+            if properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
+                storage_options['secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+                storage_options['oss_secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+            if properties.contains(OssOptions.OSS_SECURITY_TOKEN):
+                storage_options['session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+                storage_options['oss_session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+            if properties.contains(OssOptions.OSS_ENDPOINT):
+                storage_options['oss_endpoint'] = properties.get(OssOptions.OSS_ENDPOINT)
+            
             storage_options['virtual_hosted_style_request'] = 'true'
 
             if bucket and path:
                 file_path_for_lance = f"oss://{bucket}/{path}"
             elif bucket:
                 file_path_for_lance = f"oss://{bucket}"
+        else:
+            storage_options = None
 
     return file_path_for_lance, storage_options
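
To make the OSS branch above concrete: Lance wants a flat storage_options
dict plus a virtual-hosted endpoint of the form https://<bucket>.<endpoint>.
A rough sketch with plain dicts (the 'fs.oss.endpoint' key name is an
assumption about what OssOptions resolves to; pypaimon reads it via Options):

    from urllib.parse import urlparse

    def oss_storage_options(properties: dict, file_path: str):
        parsed = urlparse(file_path)              # e.g. oss://bucket/db/tbl
        bucket, path = parsed.netloc, parsed.path.lstrip('/')
        # Pass through fs.* properties, then derive the Lance endpoint.
        opts = {k: v for k, v in properties.items() if k.startswith('fs.')}
        endpoint = properties.get('fs.oss.endpoint')
        if endpoint:
            host = endpoint.replace('http://', '').replace('https://', '')
            opts['endpoint'] = f"https://{bucket}.{host}"  # bucket-first form
        opts['virtual_hosted_style_request'] = 'true'
        lance_path = f"oss://{bucket}/{path}" if path else f"oss://{bucket}"
        return lance_path, opts

    path, opts = oss_storage_options(
        {'fs.oss.endpoint': 'oss-cn-hangzhou.aliyuncs.com'},
        'oss://my-bucket/db/tbl')
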
diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py
index d556f7f55c..cb7b321083 100755
--- a/paimon-python/pypaimon/tests/rest/rest_server.py
+++ b/paimon-python/pypaimon/tests/rest/rest_server.py
@@ -24,9 +24,12 @@ import uuid
 from dataclasses import dataclass
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 from urllib.parse import urlparse
 
+if TYPE_CHECKING:
+    from pypaimon.catalog.rest.rest_token import RESTToken
+
 from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest,
                                       CreateTableRequest, RenameTableRequest)
 from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse,
@@ -213,6 +216,7 @@ class RESTCatalogServer:
         self.table_partitions_store: Dict[str, List] = {}
         self.no_permission_databases: List[str] = []
         self.no_permission_tables: List[str] = []
+        self.table_token_store: Dict[str, "RESTToken"] = {}
 
         # Initialize mock catalog (simplified)
         self.data_path = data_path
@@ -469,10 +473,12 @@ class RESTCatalogServer:
             # Basic table operations (GET, DELETE, etc.)
             return self._table_handle(method, data, lookup_identifier)
         elif len(path_parts) == 4:
-            # Extended operations (e.g., commit)
+            # Extended operations (e.g., commit, token)
             operation = path_parts[3]
             if operation == "commit":
                 return self._table_commit_handle(method, data, lookup_identifier, branch_part)
+            elif operation == "token":
+                return self._table_token_handle(method, lookup_identifier)
             else:
                 return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
         return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
@@ -574,6 +580,44 @@ class RESTCatalogServer:
 
         return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
 
+    def _table_token_handle(self, method: str, identifier: Identifier) -> Tuple[str, int]:
+        if method != "GET":
+            return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
+
+        if identifier.get_full_name() not in self.table_metadata_store:
+            raise TableNotExistException(identifier)
+
+        from pypaimon.api.api_response import GetTableTokenResponse
+
+        token_key = identifier.get_full_name()
+        if token_key in self.table_token_store:
+            rest_token = self.table_token_store[token_key]
+            response = GetTableTokenResponse(
+                token=rest_token.token,
+                expires_at_millis=rest_token.expire_at_millis
+            )
+        else:
+            default_token = {
+                "akId": "akId" + str(int(time.time() * 1000)),
+                "akSecret": "akSecret" + str(int(time.time() * 1000))
+            }
+            response = GetTableTokenResponse(
+                token=default_token,
+                expires_at_millis=int(time.time() * 1000) + 3600_000  # 1 hour from now
+            )
+
+        return self._mock_response(response, 200)
+
+    def set_table_token(self, identifier: Identifier, token: "RESTToken") -> None:
+        self.table_token_store[identifier.get_full_name()] = token
+
+    def get_table_token(self, identifier: Identifier) -> Optional["RESTToken"]:
+        return self.table_token_store.get(identifier.get_full_name())
+
+    def reset_table_token(self, identifier: Identifier) -> None:
+        if identifier.get_full_name() in self.table_token_store:
+            del self.table_token_store[identifier.get_full_name()]
+
     def _table_commit_handle(self, method: str, data: str, identifier: Identifier,
                              branch: str = None) -> Tuple[str, int]:
         """Handle table commit operations"""
