This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 76f62230f2 [python] Support pvfs in python (#5930)
76f62230f2 is described below
commit 76f62230f27da79e1a98a95c4414667c6493121b
Author: jerry <[email protected]>
AuthorDate: Thu Jul 24 18:01:55 2025 +0800
[python] Support pvfs in python (#5930)
---
.github/workflows/paimon-python-checks.yml | 2 +-
paimon-python/pypaimon/api/__init__.py | 92 ++-
paimon-python/pypaimon/api/api_response.py | 64 +-
paimon-python/pypaimon/api/api_resquest.py | 20 +
paimon-python/pypaimon/api/auth.py | 57 +-
paimon-python/pypaimon/api/client.py | 53 +-
paimon-python/pypaimon/api/config.py | 43 ++
paimon-python/pypaimon/api/rest_json.py | 12 +-
paimon-python/pypaimon/api/token_loader.py | 4 +-
paimon-python/pypaimon/api/typedef.py | 22 +-
paimon-python/pypaimon/pvfs/__init__.py | 815 +++++++++++++++++++++
paimon-python/pypaimon/tests/api_test.py | 739 +------------------
paimon-python/pypaimon/tests/pvfs_test.py | 206 ++++++
.../pypaimon/tests/{api_test.py => rest_server.py} | 401 +++-------
paimon-python/setup.py | 7 +-
15 files changed, 1378 insertions(+), 1159 deletions(-)
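For orientation, here is a minimal usage sketch of the virtual filesystem this commit adds (a sketch, assuming a reachable REST catalog; the option keys come from this patch, while the endpoint, catalog, database, and token values are placeholders):

    from pypaimon.pvfs import PaimonVirtualFileSystem

    options = {
        'uri': 'http://localhost:10000',   # REST catalog endpoint (placeholder)
        'warehouse': 'my_catalog',         # catalog name (placeholder)
        'token.provider': 'bear',          # bearer-token auth added in this patch
        'token': 'my-token',               # placeholder credential
    }
    fs = PaimonVirtualFileSystem(options)
    # Virtual paths follow pvfs://{catalog}/{database}/{table}/{sub_path}
    print(fs.ls('pvfs://my_catalog/my_db', detail=False))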
diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml
index 7426bb4dcd..4609228488 100644
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -44,7 +44,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install dependencies
run: |
- python -m pip install -q flake8==4.0.1 pytest~=7.0 requests 2>&1 >/dev/null
+ python -m pip install -q readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 pyarrow==15.0.2 numpy==1.24.3 pandas==2.0.3 flake8==4.0.1 pytest~=7.0 requests 2>&1 >/dev/null
- name: Run lint-python.sh
run: |
chmod +x paimon-python/dev/lint-python.sh
diff --git a/paimon-python/pypaimon/api/__init__.py b/paimon-python/pypaimon/api/__init__.py
index 6bafc45d98..f235c2815e 100644
--- a/paimon-python/pypaimon/api/__init__.py
+++ b/paimon-python/pypaimon/api/__init__.py
@@ -18,15 +18,23 @@
import logging
from typing import Dict, List, Optional, Callable
from urllib.parse import unquote
-
-from .api_response import PagedList, GetTableResponse, ListDatabasesResponse, ListTablesResponse, \
- GetDatabaseResponse, ConfigResponse, PagedResponse
-from .api_resquest import CreateDatabaseRequest, AlterDatabaseRequest
-from .typedef import Identifier, RESTCatalogOptions
-from .client import HttpClient
-from .auth import DLFAuthProvider, RESTAuthFunction
-from .token_loader import DLFToken, DLFTokenLoaderFactory
-from .typedef import T
+from pypaimon.api.auth import RESTAuthFunction, AuthProviderFactory
+from pypaimon.api.api_response import (
+ PagedList,
+ GetTableResponse,
+ ListDatabasesResponse,
+ ListTablesResponse,
+ GetDatabaseResponse,
+ ConfigResponse,
+ PagedResponse,
+ GetTableTokenResponse, Schema,
+)
+from pypaimon.api.api_resquest import CreateDatabaseRequest, AlterDatabaseRequest, RenameTableRequest, \
+ CreateTableRequest
+from pypaimon.api.typedef import Identifier
+from pypaimon.api.config import RESTCatalogOptions
+from pypaimon.api.client import HttpClient
+from pypaimon.api.typedef import T
class RESTException(Exception):
@@ -75,35 +83,44 @@ class ResourcePaths:
TABLES = "tables"
TABLE_DETAILS = "table-details"
- def __init__(self, base_path: str = ""):
- self.base_path = base_path.rstrip("/")
+ def __init__(self, prefix: str):
+ self.base_path = f"/{self.V1}/{prefix}".rstrip("/")
@classmethod
def for_catalog_properties(
cls, options: dict[str, str]) -> "ResourcePaths":
prefix = options.get(RESTCatalogOptions.PREFIX, "")
- return cls(f"/{cls.V1}/{prefix}" if prefix else f"/{cls.V1}")
+ return cls(prefix)
- def config(self) -> str:
- return f"/{self.V1}/config"
+ @staticmethod
+ def config() -> str:
+ return f"/{ResourcePaths.V1}/config"
def databases(self) -> str:
return f"{self.base_path}/{self.DATABASES}"
def database(self, name: str) -> str:
- return f"{self.base_path}/{self.DATABASES}/{name}"
+ return f"{self.base_path}/{self.DATABASES}/{RESTUtil.encode_string(name)}"
def tables(self, database_name: Optional[str] = None) -> str:
if database_name:
- return f"{self.base_path}/{self.DATABASES}/{database_name}/{self.TABLES}"
+ return f"{self.base_path}/{self.DATABASES}/{RESTUtil.encode_string(database_name)}/{self.TABLES}"
return f"{self.base_path}/{self.TABLES}"
def table(self, database_name: str, table_name: str) -> str:
- return f"{self.base_path}/{self.DATABASES}/{database_name}/{self.TABLES}/{table_name}"
+ return (f"{self.base_path}/{self.DATABASES}/{RESTUtil.encode_string(database_name)}"
+ f"/{self.TABLES}/{RESTUtil.encode_string(table_name)}")
def table_details(self, database_name: str) -> str:
return f"{self.base_path}/{self.DATABASES}/{database_name}/{self.TABLE_DETAILS}"
+ def table_token(self, database_name: str, table_name: str) -> str:
+ return (f"{self.base_path}/{self.DATABASES}/{RESTUtil.encode_string(database_name)}"
+ f"/{self.TABLES}/{RESTUtil.encode_string(table_name)}/token")
+
+ def rename_table(self) -> str:
+ return f"{self.base_path}/{self.TABLES}/rename"
+
class RESTApi:
HEADER_PREFIX = "header."
@@ -115,11 +132,7 @@ class RESTApi:
def __init__(self, options: Dict[str, str], config_required: bool = True):
self.logger = logging.getLogger(self.__class__.__name__)
self.client = HttpClient(options.get(RESTCatalogOptions.URI))
- auth_provider = DLFAuthProvider(
- options.get(RESTCatalogOptions.DLF_REGION),
- DLFToken.from_options(options),
- DLFTokenLoaderFactory.create_token_loader(options)
- )
+ auth_provider = AuthProviderFactory.create_auth_provider(options)
base_headers = RESTUtil.extract_prefix_map(options, self.HEADER_PREFIX)
if config_required:
@@ -130,10 +143,10 @@ class RESTApi:
warehouse)
config_response = self.client.get_with_params(
- ResourcePaths().config(),
+ ResourcePaths.config(),
query_params,
ConfigResponse,
- RESTAuthFunction({}, auth_provider),
+ RESTAuthFunction(base_headers, auth_provider),
)
options = config_response.merge(options)
base_headers.update(
@@ -285,6 +298,13 @@ class RESTApi:
tables = response.data() or []
return PagedList(tables, response.get_next_page_token())
+ def create_table(self, identifier: Identifier, schema: Schema) -> None:
+ request = CreateTableRequest(identifier, schema)
+ return self.client.post(
+ self.resource_paths.tables(identifier.database_name),
+ request,
+ self.rest_auth_function)
+
def get_table(self, identifier: Identifier) -> GetTableResponse:
return self.client.get(
self.resource_paths.table(
@@ -293,3 +313,27 @@ class RESTApi:
GetTableResponse,
self.rest_auth_function,
)
+
+ def drop_table(self, identifier: Identifier) -> GetTableResponse:
+ return self.client.delete(
+ self.resource_paths.table(
+ identifier.database_name,
+ identifier.object_name),
+ self.rest_auth_function,
+ )
+
+ def rename_table(self, source_identifier: Identifier, target_identifier: Identifier) -> None:
+ request = RenameTableRequest(source_identifier, target_identifier)
+ return self.client.post(
+ self.resource_paths.rename_table(),
+ request,
+ self.rest_auth_function)
+
+ def load_table_token(self, identifier: Identifier) -> GetTableTokenResponse:
+ return self.client.get(
+ self.resource_paths.table_token(
+ identifier.database_name,
+ identifier.object_name),
+ GetTableTokenResponse,
+ self.rest_auth_function,
+ )
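Taken together, the new table operations compose roughly like this (a sketch, assuming the same placeholder catalog options as above; the identifiers are illustrative):

    from pypaimon.api import RESTApi, Schema
    from pypaimon.api.typedef import Identifier

    api = RESTApi({'uri': 'http://localhost:10000', 'warehouse': 'my_catalog',
                   'token.provider': 'bear', 'token': 'my-token'})
    ident = Identifier.create('my_db', 'my_table')
    api.create_table(ident, Schema(options={'type': 'object-table'}))
    token = api.load_table_token(ident)   # temporary storage credentials
    api.rename_table(ident, Identifier.create('my_db', 'my_table_v2'))
    api.drop_table(Identifier.create('my_db', 'my_table_v2'))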
diff --git a/paimon-python/pypaimon/api/api_response.py b/paimon-python/pypaimon/api/api_response.py
index 61169f7cd3..7e0b6a564a 100644
--- a/paimon-python/pypaimon/api/api_response.py
+++ b/paimon-python/pypaimon/api/api_response.py
@@ -37,18 +37,17 @@ class RESTResponse(ABC):
@dataclass
class ErrorResponse(RESTResponse):
-
resource_type: Optional[str] = json_field("resourceType", default=None)
resource_name: Optional[str] = json_field("resourceName", default=None)
message: Optional[str] = json_field("message", default=None)
code: Optional[int] = json_field("code", default=None)
def __init__(
- self,
- resource_type: Optional[str] = None,
- resource_name: Optional[str] = None,
- message: Optional[str] = None,
- code: Optional[int] = None,
+ self,
+ resource_type: Optional[str] = None,
+ resource_name: Optional[str] = None,
+ message: Optional[str] = None,
+ code: Optional[int] = None,
):
self.resource_type = resource_type
self.resource_name = resource_name
@@ -203,18 +202,18 @@ class GetTableResponse(AuditRESTResponse):
schema: Optional[Schema] = json_field(FIELD_SCHEMA, default=None)
def __init__(
- self,
- id: str,
- name: str,
- path: str,
- is_external: bool,
- schema_id: int,
- schema: Schema,
- owner: Optional[str] = None,
- created_at: Optional[int] = None,
- created_by: Optional[str] = None,
- updated_at: Optional[int] = None,
- updated_by: Optional[str] = None,
+ self,
+ id: str,
+ name: str,
+ path: str,
+ is_external: bool,
+ schema_id: int,
+ schema: Schema,
+ owner: Optional[str] = None,
+ created_at: Optional[int] = None,
+ created_by: Optional[str] = None,
+ updated_at: Optional[int] = None,
+ updated_by: Optional[str] = None,
):
super().__init__(owner, created_at, created_by, updated_at, updated_by)
self.id = id
@@ -239,16 +238,16 @@ class GetDatabaseResponse(AuditRESTResponse):
FIELD_OPTIONS, default_factory=dict)
def __init__(
- self,
- id: Optional[str] = None,
- name: Optional[str] = None,
- location: Optional[str] = None,
- options: Optional[Dict[str, str]] = None,
- owner: Optional[str] = None,
- created_at: Optional[int] = None,
- created_by: Optional[str] = None,
- updated_at: Optional[int] = None,
- updated_by: Optional[str] = None,
+ self,
+ id: Optional[str] = None,
+ name: Optional[str] = None,
+ location: Optional[str] = None,
+ options: Optional[Dict[str, str]] = None,
+ owner: Optional[str] = None,
+ created_at: Optional[int] = None,
+ created_by: Optional[str] = None,
+ updated_at: Optional[int] = None,
+ updated_by: Optional[str] = None,
):
super().__init__(owner, created_at, created_by, updated_at, updated_by)
self.id = id
@@ -279,3 +278,12 @@ class ConfigResponse(RESTResponse):
merged = options.copy()
merged.update(self.defaults)
return merged
+
+
+@dataclass
+class GetTableTokenResponse(RESTResponse):
+ FIELD_TOKEN = "token"
+ FIELD_EXPIRES_AT_MILLIS = "expiresAtMillis"
+
+ token: Dict[str, str] = json_field(FIELD_TOKEN, default=None)
+ expires_at_millis: Optional[int] = json_field(FIELD_EXPIRES_AT_MILLIS, default=None)
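The wire shape these fields imply looks roughly like this (a sketch; the key names come from this patch and the fs.oss.* keys defined in the new config.py below, the values are placeholders):

    # Hypothetical payload for GET .../tables/{table}/token:
    payload = {
        "token": {
            "fs.oss.accessKeyId": "<access key id>",
            "fs.oss.accessKeySecret": "<access key secret>",
            "fs.oss.securityToken": "<sts token>",
        },
        "expiresAtMillis": 1753351315000,  # absolute expiry, epoch millis
    }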
diff --git a/paimon-python/pypaimon/api/api_resquest.py b/paimon-python/pypaimon/api/api_resquest.py
index 8e7d14e418..7bc4f6b2a1 100644
--- a/paimon-python/pypaimon/api/api_resquest.py
+++ b/paimon-python/pypaimon/api/api_resquest.py
@@ -20,6 +20,8 @@ from abc import ABC
from dataclasses import dataclass
from typing import Dict, List
+from .api_response import Schema
+from .typedef import Identifier
from .rest_json import json_field
@@ -43,3 +45,21 @@ class AlterDatabaseRequest(RESTRequest):
removals: List[str] = json_field(FIELD_REMOVALS)
updates: Dict[str, str] = json_field(FIELD_UPDATES)
+
+
+@dataclass
+class RenameTableRequest(RESTRequest):
+ FIELD_SOURCE = "source"
+ FIELD_DESTINATION = "destination"
+
+ source: Identifier = json_field(FIELD_SOURCE)
+ destination: Identifier = json_field(FIELD_DESTINATION)
+
+
+@dataclass
+class CreateTableRequest(RESTRequest):
+ FIELD_IDENTIFIER = "identifier"
+ FIELD_SCHEMA = "schema"
+
+ identifier: Identifier = json_field(FIELD_IDENTIFIER)
+ schema: Schema = json_field(FIELD_SCHEMA)
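Constructing the new request objects is straightforward (a sketch; the class and field names are from this patch, the values are placeholders):

    from pypaimon.api.api_resquest import CreateTableRequest, RenameTableRequest
    from pypaimon.api.api_response import Schema
    from pypaimon.api.typedef import Identifier

    create = CreateTableRequest(
        identifier=Identifier.create('my_db', 'my_table'),
        schema=Schema(options={'type': 'object-table'}),
    )
    rename = RenameTableRequest(
        source=Identifier.create('my_db', 'my_table'),
        destination=Identifier.create('my_db', 'my_table_v2'),
    )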
diff --git a/paimon-python/pypaimon/api/auth.py b/paimon-python/pypaimon/api/auth.py
index b30c1ec025..dc51df0740 100644
--- a/paimon-python/pypaimon/api/auth.py
+++ b/paimon-python/pypaimon/api/auth.py
@@ -27,13 +27,14 @@ from typing import Optional, Dict
from .token_loader import DLFTokenLoader, DLFToken
from .typedef import RESTAuthParameter
+from .config import RESTCatalogOptions
class AuthProvider(ABC):
@abstractmethod
def merge_auth_header(
- self, base_header: Dict[str, str], parammeter: RESTAuthParameter
+ self, base_header: Dict[str, str], parammeter: RESTAuthParameter
) -> Dict[str, str]:
"""Merge authorization header into header."""
@@ -41,8 +42,7 @@ class AuthProvider(ABC):
class RESTAuthFunction:
def __init__(self,
- init_header: Dict[str,
- str],
+ init_header: Dict[str, str],
auth_provider: AuthProvider):
self.init_header = init_header.copy() if init_header else {}
self.auth_provider = auth_provider
@@ -57,6 +57,35 @@ class RESTAuthFunction:
return self.__call__(rest_auth_parameter)
+class AuthProviderFactory:
+
+ @staticmethod
+ def create_auth_provider(options: Dict[str, str]) -> AuthProvider:
+ provider = options.get(RESTCatalogOptions.TOKEN_PROVIDER)
+ if provider == 'bear':
+ token = options.get(RESTCatalogOptions.TOKEN)
+ return BearTokenAuthProvider(token)
+ elif provider == 'dlf':
+ return DLFAuthProvider(
+ options.get(RESTCatalogOptions.DLF_REGION),
+ DLFToken.from_options(options)
+ )
+ raise ValueError('Unknown auth provider')
+
+
+class BearTokenAuthProvider(AuthProvider):
+
+ def __init__(self, token: str):
+ self.token = token
+
+ def merge_auth_header(
+ self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
+ ) -> Dict[str, str]:
+ headers_with_auth = base_header.copy()
+ headers_with_auth['Authorization'] = f'Bearer {self.token}'
+ return headers_with_auth
+
+
class DLFAuthProvider(AuthProvider):
DLF_AUTHORIZATION_HEADER_KEY = "Authorization"
DLF_CONTENT_MD5_HEADER_KEY = "Content-MD5"
@@ -99,7 +128,7 @@ class DLFAuthProvider(AuthProvider):
return self.token
def merge_auth_header(
- self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
+ self, base_header: Dict[str, str], rest_auth_parameter: RESTAuthParameter
) -> Dict[str, str]:
try:
date_time = base_header.get(
@@ -134,7 +163,7 @@ class DLFAuthProvider(AuthProvider):
@classmethod
def generate_sign_headers(
- cls, data: Optional[str], date_time: str, security_token: Optional[str]
+ cls, data: Optional[str], date_time: str, security_token: Optional[str]
) -> Dict[str, str]:
sign_headers = {}
@@ -172,13 +201,13 @@ class DLFAuthSignature:
@classmethod
def get_authorization(
- cls,
- rest_auth_parameter: RESTAuthParameter,
- dlf_token: DLFToken,
- region: str,
- headers: Dict[str, str],
- date_time: str,
- date: str,
+ cls,
+ rest_auth_parameter: RESTAuthParameter,
+ dlf_token: DLFToken,
+ region: str,
+ headers: Dict[str, str],
+ date_time: str,
+ date: str,
) -> str:
try:
canonical_request = cls.get_canonical_request(
@@ -231,7 +260,7 @@ class DLFAuthSignature:
@classmethod
def get_canonical_request(
- cls, rest_auth_parameter: RESTAuthParameter, headers: Dict[str, str]
+ cls, rest_auth_parameter: RESTAuthParameter, headers: Dict[str, str]
) -> str:
canonical_request = cls.NEW_LINE.join(
[rest_auth_parameter.method, rest_auth_parameter.path]
@@ -278,7 +307,7 @@ class DLFAuthSignature:
@classmethod
def _build_sorted_signed_headers_map(
- cls, headers: Optional[Dict[str, str]]
+ cls, headers: Optional[Dict[str, str]]
) -> OrderedDict:
sorted_headers = OrderedDict()
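A minimal sketch of the new bearer-token flow (assuming RESTAuthParameter's fields method, path, data, and parameters as referenced in this module; the values are placeholders):

    from pypaimon.api.auth import BearTokenAuthProvider
    from pypaimon.api.typedef import RESTAuthParameter

    provider = BearTokenAuthProvider('my-token')
    param = RESTAuthParameter(method='GET', path='/v1/config', data=None, parameters={})
    headers = provider.merge_auth_header({'User-Agent': 'PythonPVFS'}, param)
    # headers == {'User-Agent': 'PythonPVFS', 'Authorization': 'Bearer my-token'}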
diff --git a/paimon-python/pypaimon/api/client.py b/paimon-python/pypaimon/api/client.py
index d83f79a202..7f7203a1c5 100644
--- a/paimon-python/pypaimon/api/client.py
+++ b/paimon-python/pypaimon/api/client.py
@@ -18,6 +18,7 @@ limitations under the License.
import json
import logging
+import time
import traceback
import urllib.parse
from abc import ABC, abstractmethod
@@ -175,7 +176,7 @@ class DefaultErrorHandler(ErrorHandler):
code = error.code
# Format message with request ID if not default
- if LoggingInterceptor.DEFAULT_REQUEST_ID == request_id:
+ if HttpClient.DEFAULT_REQUEST_ID == request_id:
message = error.message
else:
# If we have a requestId, append it to the message
@@ -256,22 +257,6 @@ class ExponentialRetry:
return Retry(**retry_kwargs)
-class LoggingInterceptor:
- REQUEST_ID_KEY = "x-request-id"
- DEFAULT_REQUEST_ID = "unknown"
-
- def __init__(self):
- self.logger = logging.getLogger(self.__class__.__name__)
-
- def log_request(self, method: str, url: str, headers: Dict[str, str]) -> None:
- request_id = headers.get(self.REQUEST_ID_KEY, self.DEFAULT_REQUEST_ID)
- self.logger.debug(f"Request [{request_id}]: {method} {url}")
-
- def log_response(self, status_code: int, headers: Dict[str, str]) -> None:
- request_id = headers.get(self.REQUEST_ID_KEY, self.DEFAULT_REQUEST_ID)
- self.logger.debug(f"Response [{request_id}]: {status_code}")
-
-
class RESTClient(ABC):
@abstractmethod
@@ -360,19 +345,19 @@ def _get_headers(path: str, method: str, query_params: Dict[str, str], data: str
class HttpClient(RESTClient):
+ REQUEST_ID_KEY = "x-request-id"
+ DEFAULT_REQUEST_ID = "unknown"
+
def __init__(self, uri: str):
self.logger = logging.getLogger(self.__class__.__name__)
self.uri = _normalize_uri(uri)
self.error_handler = DefaultErrorHandler.get_instance()
- self.logging_interceptor = LoggingInterceptor()
-
self.session = requests.Session()
retry_interceptor = ExponentialRetry(max_retries=3)
- adapter = HTTPAdapter(max_retries=retry_interceptor.adapter)
- self.session.mount("http://", adapter)
- self.session.mount("https://", adapter)
+ self.session.mount("http://", retry_interceptor.adapter)
+ self.session.mount("https://", retry_interceptor.adapter)
self.session.timeout = (180, 180)
@@ -455,27 +440,29 @@ class HttpClient(RESTClient):
headers: Optional[Dict[str, str]] = None,
response_type: Optional[Type[T]] = None) -> T:
try:
- if headers:
- self.logging_interceptor.log_request(method, url, headers)
-
+ start_time = time.time_ns()
response = self.session.request(
method=method,
url=url,
data=data.encode('utf-8') if data else None,
headers=headers
)
-
- response_headers = dict(response.headers)
- self.logging_interceptor.log_response(response.status_code, response_headers)
-
+ duration_ms = (time.time_ns() - start_time) // 1_000_000
+ response_request_id = response.headers.get(self.REQUEST_ID_KEY, self.DEFAULT_REQUEST_ID)
+
+ self.logger.info(
+ "[rest] requestId:%s method:%s url:%s status:%d duration:%dms",
+ response_request_id,
+ response.request.method,
+ response.url,
+ response.status_code,
+ duration_ms
+ )
response_body_str = response.text if response.text else None
if not response.ok:
error = _parse_error_response(response_body_str, response.status_code)
- request_id = response.headers.get(
- LoggingInterceptor.REQUEST_ID_KEY,
- LoggingInterceptor.DEFAULT_REQUEST_ID
- )
+ request_id = response.headers.get(self.REQUEST_ID_KEY, self.DEFAULT_REQUEST_ID)
self.error_handler.accept(error, request_id)
if response_type is not None and response_body_str is not None:
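With the interceptor inlined, each request now produces a single timed log line built from the format string above; roughly (the values here are illustrative):

    [rest] requestId:unknown method:GET url:http://localhost:10000/v1/config status:200 duration:42ms

The requestId comes from the x-request-id response header and falls back to "unknown" when the server does not set it.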
diff --git a/paimon-python/pypaimon/api/config.py b/paimon-python/pypaimon/api/config.py
new file mode 100644
index 0000000000..dcdbb95c50
--- /dev/null
+++ b/paimon-python/pypaimon/api/config.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+class OssOptions:
+ OSS_ACCESS_KEY_ID = "fs.oss.accessKeyId"
+ OSS_ACCESS_KEY_SECRET = "fs.oss.accessKeySecret"
+ OSS_SECURITY_TOKEN = "fs.oss.securityToken"
+ OSS_ENDPOINT = "fs.oss.endpoint"
+
+
+class RESTCatalogOptions:
+ URI = "uri"
+ WAREHOUSE = "warehouse"
+ TOKEN_PROVIDER = "token.provider"
+ TOKEN = "token"
+ DLF_REGION = "dlf.region"
+ DLF_ACCESS_KEY_ID = "dlf.access-key-id"
+ DLF_ACCESS_KEY_SECRET = "dlf.access-key-secret"
+ DLF_ACCESS_SECURITY_TOKEN = "dlf.security-token"
+ DLF_TOKEN_LOADER = "dlf.token-loader"
+ DLF_TOKEN_ECS_ROLE_NAME = "dlf.token-ecs-role-name"
+ DLF_TOKEN_ECS_METADATA_URL = "dlf.token-ecs-metadata-url"
+ PREFIX = 'prefix'
+ HTTP_USER_AGENT_HEADER = 'header.HTTP_USER_AGENT'
+
+
+class PVFSOptions:
+ DEFAULT_CACHE_SIZE = 20
+ CACHE_SIZE = "cache_size"
diff --git a/paimon-python/pypaimon/api/rest_json.py b/paimon-python/pypaimon/api/rest_json.py
index 95291ecb03..124dd69254 100644
--- a/paimon-python/pypaimon/api/rest_json.py
+++ b/paimon-python/pypaimon/api/rest_json.py
@@ -17,9 +17,9 @@
import json
from dataclasses import field, fields, is_dataclass
-from typing import Any, Type, Dict
+from typing import Any, Type, Dict, TypeVar
-from .typedef import T
+T = TypeVar("T")
def json_field(json_name: str, **kwargs):
@@ -70,15 +70,21 @@ class JSON:
"""Create instance from dictionary"""
# Create field name mapping (json_name -> field_name)
field_mapping = {}
+ type_mapping = {}
for field_info in fields(target_class):
json_name = field_info.metadata.get("json_name", field_info.name)
field_mapping[json_name] = field_info.name
+ if is_dataclass(field_info.type):
+ type_mapping[json_name] = field_info.type
# Map JSON data to field names
kwargs = {}
for json_name, value in data.items():
if json_name in field_mapping:
field_name = field_mapping[json_name]
- kwargs[field_name] = value
+ if field_name in type_mapping:
+ kwargs[field_name] = JSON.__from_dict(value, type_mapping[json_name])
+ else:
+ kwargs[field_name] = value
return target_class(**kwargs)
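What the new type_mapping enables, sketched with hypothetical dataclasses (Inner and Outer are illustrative only; json_field is from this module):

    from dataclasses import dataclass
    from pypaimon.api.rest_json import json_field

    @dataclass
    class Inner:
        value: str = json_field("value", default=None)

    @dataclass
    class Outer:
        inner: Inner = json_field("inner", default=None)

    # Previously, deserializing {"inner": {"value": "x"}} left Outer.inner as a
    # plain dict; with the is_dataclass check above, the nested dict is now
    # recursively materialized into an Inner instance.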
diff --git a/paimon-python/pypaimon/api/token_loader.py b/paimon-python/pypaimon/api/token_loader.py
index 46810a1121..223490646c 100644
--- a/paimon-python/pypaimon/api/token_loader.py
+++ b/paimon-python/pypaimon/api/token_loader.py
@@ -25,7 +25,7 @@ from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from .rest_json import json_field, JSON
-from .typedef import RESTCatalogOptions
+from .config import RESTCatalogOptions
from .client import ExponentialRetry
@@ -59,7 +59,7 @@ class DLFToken:
@classmethod
def from_options(cls, options: Dict[str, str]) -> Optional['DLFToken']:
- from .typedef import RESTCatalogOptions
+ from .config import RESTCatalogOptions
if (options.get(RESTCatalogOptions.DLF_ACCESS_KEY_ID) is None
or options.get(RESTCatalogOptions.DLF_ACCESS_KEY_SECRET) is None):
return None
diff --git a/paimon-python/pypaimon/api/typedef.py b/paimon-python/pypaimon/api/typedef.py
index 4cf65738f9..0501d5e1e1 100644
--- a/paimon-python/pypaimon/api/typedef.py
+++ b/paimon-python/pypaimon/api/typedef.py
@@ -18,6 +18,8 @@
from dataclasses import dataclass
from typing import Optional, TypeVar, Dict
+from pypaimon.api.rest_json import json_field
+
T = TypeVar("T")
@@ -25,9 +27,9 @@ T = TypeVar("T")
class Identifier:
"""Table/View/Function identifier"""
- database_name: str
- object_name: str
- branch_name: Optional[str] = None
+ database_name: str = json_field("database", default=None)
+ object_name: str = json_field("object", default=None)
+ branch_name: Optional[str] = json_field("branch", default=None)
@classmethod
def create(cls, database_name: str, object_name: str) -> "Identifier":
@@ -70,17 +72,3 @@ class RESTAuthParameter:
path: str
data: str
parameters: Dict[str, str]
-
-
-class RESTCatalogOptions:
- URI = "uri"
- WAREHOUSE = "warehouse"
- TOKEN_PROVIDER = "token.provider"
- DLF_REGION = "dlf.region"
- DLF_ACCESS_KEY_ID = "dlf.access-key-id"
- DLF_ACCESS_KEY_SECRET = "dlf.access-key-secret"
- DLF_ACCESS_SECURITY_TOKEN = "dlf.security-token"
- DLF_TOKEN_LOADER = "dlf.token-loader"
- DLF_TOKEN_ECS_ROLE_NAME = "dlf.token-ecs-role-name"
- DLF_TOKEN_ECS_METADATA_URL = "dlf.token-ecs-metadata-url"
- PREFIX = 'prefix'
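With these json_field mappings, an Identifier now round-trips through JSON under the shape below (a sketch; the wire names come from this patch):

    from pypaimon.api.typedef import Identifier

    ident = Identifier.create('my_db', 'my_table')
    # Implied JSON form: {"database": "my_db", "object": "my_table", "branch": null}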
diff --git a/paimon-python/pypaimon/pvfs/__init__.py b/paimon-python/pypaimon/pvfs/__init__.py
new file mode 100644
index 0000000000..36ba2f2909
--- /dev/null
+++ b/paimon-python/pypaimon/pvfs/__init__.py
@@ -0,0 +1,815 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import importlib
+import re
+import time
+import datetime
+from abc import ABC
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, Any, Optional, Tuple
+
+from cachetools import LRUCache
+from readerwriterlock import rwlock
+
+import fsspec
+from fsspec import AbstractFileSystem
+from fsspec.implementations.local import LocalFileSystem
+
+from pypaimon.api import RESTApi, GetTableTokenResponse, Schema, GetTableResponse
+from pypaimon.api.client import NoSuchResourceException, AlreadyExistsException
+from pypaimon.api.typedef import Identifier
+from pypaimon.api.config import RESTCatalogOptions, OssOptions, PVFSOptions
+
+PROTOCOL_NAME = "pvfs"
+
+
+class StorageType(Enum):
+ LOCAL = "file"
+ OSS = "oss"
+
+
+class PVFSIdentifier(ABC):
+ catalog: str
+
+
+@dataclass
+class PVFSCatalogIdentifier(PVFSIdentifier):
+ catalog: str
+
+
+@dataclass
+class PVFSDatabaseIdentifier(PVFSIdentifier):
+ database: str
+ catalog: str
+
+
+@dataclass
+class PVFSTableIdentifier(PVFSIdentifier):
+ catalog: str
+ database: str
+ table: str
+ sub_path: str = None
+
+ def __hash__(self) -> int:
+ return hash((self.catalog, self.database, self.table))
+
+ def __eq__(self, __value: Any) -> bool:
+ if isinstance(__value, PVFSTableIdentifier):
+ return self.catalog == __value.catalog and self.database == __value.database and self.table == __value.table
+ return False
+
+ def get_actual_path(self, storage_location: str):
+ if self.sub_path:
+ return f'{storage_location.rstrip("/")}/{self.sub_path.lstrip("/")}'
+ return storage_location
+
+ def get_virtual_location(self):
+ return (f'{PROTOCOL_NAME}://{self.catalog}'
+ f'/{self.database}/{self.table}')
+
+ def get_identifier(self):
+ return Identifier.create(self.database, self.table)
+
+
+@dataclass
+class PaimonRealStorage:
+ TOKEN_EXPIRATION_SAFE_TIME_MILLIS = 3_600_000
+
+ token: Dict[str, str]
+ expires_at_millis: Optional[int]
+ file_system: AbstractFileSystem
+
+ def need_refresh(self) -> bool:
+ if self.expires_at_millis is not None:
+ return self.expires_at_millis - int(time.time() * 1000) < self.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+ return False
+
+
+class PaimonVirtualFileSystem(fsspec.AbstractFileSystem):
+ options: Dict[str, Any]
+
+ protocol = PROTOCOL_NAME
+ _identifier_pattern = re.compile("^pvfs://([^/]+)/([^/]+)/([^/]+)(?:/[^/]+)*/?$")
+
+ def __init__(self, options: Dict = None, **kwargs):
+ options.update({RESTCatalogOptions.HTTP_USER_AGENT_HEADER: 'PythonPVFS'})
+ self.options = options
+ self.warehouse = options.get(RESTCatalogOptions.WAREHOUSE)
+ cache_size = (
+ PVFSOptions.DEFAULT_CACHE_SIZE
+ if options is None
+ else options.get(PVFSOptions.CACHE_SIZE, PVFSOptions.DEFAULT_CACHE_SIZE)
+ )
+ self._rest_client_cache = LRUCache(cache_size)
+ self._cache = LRUCache(maxsize=cache_size)
+ self._cache_lock = rwlock.RWLockFair()
+ super().__init__(**kwargs)
+
+ def __rest_api(self, catalog: str):
+ rest_api = self._rest_client_cache.get(catalog)
+ if rest_api is None:
+ options = self.options.copy()
+ options.update({RESTCatalogOptions.WAREHOUSE: catalog})
+ rest_api = RESTApi(options)
+ self._rest_client_cache[catalog] = rest_api
+ return rest_api
+
+ @property
+ def fsid(self):
+ return PROTOCOL_NAME
+
+ def sign(self, path, expiration=None, **kwargs):
+ """We do not support to create a signed URL representing the given
path in gvfs."""
+ raise Exception(
+ "Sign is not implemented for Paimon Virtual FileSystem."
+ )
+
+ def ls(self, path, detail=True, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ databases = rest_api.list_databases()
+ if detail:
+ return [
+ self._create_dir_detail(
+ self._convert_database_virtual_path(pvfs_identifier.catalog, database)
+ )
+ for database in databases
+ ]
+ return [
+ self._convert_database_virtual_path(pvfs_identifier.catalog, database)
+ for database in databases
+ ]
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ tables = rest_api.list_tables(pvfs_identifier.database)
+ if detail:
+ return [
+ self._create_dir_detail(
+ self._convert_table_virtual_path(pvfs_identifier.catalog, pvfs_identifier.database, table)
+ )
+ for table in tables
+ ]
+ return [
+ self._convert_table_virtual_path(pvfs_identifier.catalog, pvfs_identifier.database, table)
+ for table in tables
+ ]
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table = rest_api.get_table(Identifier.create(pvfs_identifier.database, pvfs_identifier.table))
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ virtual_location = pvfs_identifier.get_virtual_location()
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ entries = fs.ls(actual_path, detail=detail, **kwargs)
+ if detail:
+ virtual_entities = [
+ self._convert_actual_info(entry, storage_type, storage_location, virtual_location)
+ for entry in entries
+ ]
+ return virtual_entities
+ else:
+ virtual_entry_paths = [
+ self._convert_actual_path(
+ storage_type, entry_path, storage_location, virtual_location
+ )
+ for entry_path in entries
+ ]
+ return virtual_entry_paths
+
+ def info(self, path, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ return self._create_dir_detail(f'{PROTOCOL_NAME}://{pvfs_identifier.catalog}')
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ return self._create_dir_detail(
+ self._convert_database_virtual_path(pvfs_identifier.catalog, pvfs_identifier.database)
+ )
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ table = rest_api.get_table(Identifier.create(pvfs_identifier.database, pvfs_identifier.table))
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ virtual_location = pvfs_identifier.get_virtual_location()
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ entry = fs.info(actual_path)
+ return self._convert_actual_info(entry, storage_type, storage_location, virtual_location)
+
+ def exists(self, path, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ return True
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ try:
+ rest_api.get_database(pvfs_identifier.database)
+ return True
+ except NoSuchResourceException:
+ return False
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ try:
+ table = rest_api.get_table(Identifier.create(pvfs_identifier.database, pvfs_identifier.table))
+ if pvfs_identifier.sub_path is None:
+ return True
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.exists(actual_path)
+ except NoSuchResourceException:
+ return False
+
+ def cp_file(self, path1, path2, **kwargs):
+ source = self._extract_pvfs_identifier(path1)
+ target = self._extract_pvfs_identifier(path2)
+ if ((isinstance(source, PVFSTableIdentifier)
+ and isinstance(target, PVFSTableIdentifier))
+ and target.sub_path is not None
+ and source.sub_path is not None
+ and source == target):
+ rest_api = self.__rest_api(source.catalog)
+ table_identifier = source.get_identifier()
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ source_actual_path = source.get_actual_path(storage_location)
+ target_actual_path = target.get_actual_path(storage_location)
+ fs = self._get_filesystem(source, storage_type)
+ fs.cp_file(
+ self._strip_storage_protocol(storage_type, source_actual_path),
+ self._strip_storage_protocol(storage_type, target_actual_path),
+ )
+ return None
+ raise Exception(
+ f"cp is not supported for path: {path1} to path: {path2}"
+ )
+
+ def mv(self, path1, path2, recursive=False, maxdepth=None, **kwargs):
+ source = self._extract_pvfs_identifier(path1)
+ target = self._extract_pvfs_identifier(path2)
+ if (isinstance(source, PVFSTableIdentifier) and
+ isinstance(target, PVFSTableIdentifier) and
+ target.catalog == source.catalog):
+ rest_api = self.__rest_api(source.catalog)
+ if target.sub_path is None and source.sub_path is None:
+ source_identifier = Identifier.create(source.database, source.table)
+ target_identifier = Identifier.create(target.database, target.table)
+ rest_api.rename_table(source_identifier, target_identifier)
+ return None
+ elif target.sub_path is not None and source.sub_path is not None and target == source:
+ table_identifier = source.get_identifier()
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ source_actual_path = source.get_actual_path(storage_location)
+ target_actual_path = target.get_actual_path(storage_location)
+ fs = self._get_filesystem(source, storage_type)
+ if storage_type == StorageType.LOCAL:
+ fs.mv(
+ self._strip_storage_protocol(storage_type, source_actual_path),
+ self._strip_storage_protocol(storage_type, target_actual_path),
+ recursive=recursive,
+ maxdepth=maxdepth
+ )
+ else:
+ fs.mv(
+ self._strip_storage_protocol(storage_type, source_actual_path),
+ self._strip_storage_protocol(storage_type, target_actual_path),
+ )
+ return None
+ raise Exception(
+ f"Mv is not supported for path: {path1} to path: {path2}"
+ )
+
+ def rm(self, path, recursive=False, maxdepth=None):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ database_name = pvfs_identifier.database
+ if not recursive and len(rest_api.list_tables(database_name)) > 0:
+ raise Exception('Recursive is False but database is not empty')
+ rest_api.drop_database(database_name)
+ return True
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ table = rest_api.get_table(table_identifier)
+ if pvfs_identifier.sub_path is None:
+ rest_api.drop_table(table_identifier)
+ return True
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.rm(
+ self._strip_storage_protocol(storage_type, actual_path),
+ recursive,
+ maxdepth,
+ )
+ raise Exception(
+ f"Rm is not supported for path: {path}."
+ )
+
+ def rm_file(self, path):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ if isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ table = rest_api.get_table(table_identifier)
+ if pvfs_identifier.sub_path is not None:
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.rm_file(
+ self._strip_storage_protocol(storage_type, actual_path),
+ )
+ raise Exception(
+ f"Rm file is not supported for path: {path}."
+ )
+
+ def rmdir(self, path):
+ files = self.ls(path)
+ if len(files) == 0:
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ database_name = pvfs_identifier.database
+ rest_api.drop_database(database_name)
+ return True
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ table = rest_api.get_table(table_identifier)
+ if pvfs_identifier.sub_path is None:
+ rest_api.drop_table(table_identifier)
+ self._cache.pop(pvfs_identifier)
+ return True
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.rmdir(
+ self._strip_storage_protocol(storage_type, actual_path)
+ )
+ raise Exception(
+ f"Rm dir is not supported for path: {path}."
+ )
+ else:
+ raise Exception(
+ f"Rm dir is not supported for path: {path} as it is not empty."
+ )
+
+ def open(
+ self,
+ path,
+ mode="rb",
+ block_size=None,
+ cache_options=None,
+ compression=None,
+ **kwargs
+ ):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"open is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ raise Exception(
+ f"open is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ table_identifier = pvfs_identifier.get_identifier()
+ table = rest_api.get_table(table_identifier)
+ if pvfs_identifier.sub_path is None:
+ raise Exception(
+ f"open is not supported for path: {path}"
+ )
+ else:
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.open(
+ self._strip_storage_protocol(storage_type, actual_path),
+ mode,
+ block_size,
+ cache_options,
+ compression,
+ **kwargs
+ )
+
+ def mkdir(self, path, create_parents=True, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"mkdir is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ rest_api.create_database(pvfs_identifier.database, {})
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ if create_parents:
+ try:
+ rest_api.create_database(pvfs_identifier.database, {})
+ except AlreadyExistsException:
+ pass
+ self._create_object_table(pvfs_identifier)
+ else:
+ table: GetTableResponse
+ if create_parents:
+ try:
+ rest_api.create_database(pvfs_identifier.database, {})
+ except AlreadyExistsException:
+ pass
+ try:
+ table = rest_api.get_table(table_identifier)
+ except NoSuchResourceException:
+ try:
+ self._create_object_table(pvfs_identifier)
+ except AlreadyExistsException:
+ pass
+ finally:
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.mkdir(
+ self._strip_storage_protocol(storage_type, actual_path),
+ create_parents,
+ **kwargs
+ )
+
+ def makedirs(self, path, exist_ok=True):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"makedirs is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ try:
+ rest_api.create_database(pvfs_identifier.database, {})
+ except AlreadyExistsException as e:
+ if not exist_ok:
+ raise e
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ try:
+ self._create_object_table(pvfs_identifier)
+ except AlreadyExistsException as e:
+ if not exist_ok:
+ raise e
+ else:
+ try:
+ self._create_object_table(pvfs_identifier)
+ except AlreadyExistsException as e:
+ if exist_ok:
+ pass
+ else:
+ raise e
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.makedirs(
+ self._strip_storage_protocol(storage_type, actual_path),
+ exist_ok
+ )
+
+ def created(self, path):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"created is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ return self.__converse_ts_to_datatime(rest_api.get_database(pvfs_identifier.database).created_at)
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ return self.__converse_ts_to_datatime(rest_api.get_table(table_identifier).created_at)
+ else:
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.created(
+ self._strip_storage_protocol(storage_type, actual_path)
+ )
+
+ def modified(self, path):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"modified is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ return self.__converse_ts_to_datatime(rest_api.get_database(pvfs_identifier.database).updated_at)
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ return self.__converse_ts_to_datatime(rest_api.get_table(table_identifier).updated_at)
+ else:
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.modified(
+ self._strip_storage_protocol(storage_type, actual_path)
+ )
+
+ def cat_file(self, path, start=None, end=None, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(path)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"cat file is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ raise Exception(
+ f"cat file is not supported for path: {path}"
+ )
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ raise Exception(
+ f"cat file is not supported for path: {path}"
+ )
+ else:
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.cat_file(
+ self._strip_storage_protocol(storage_type, actual_path),
+ start,
+ end,
+ **kwargs,
+ )
+
+ def get_file(self, rpath, lpath, callback=None, outfile=None, **kwargs):
+ pvfs_identifier = self._extract_pvfs_identifier(rpath)
+ if isinstance(pvfs_identifier, PVFSCatalogIdentifier):
+ raise Exception(
+ f"get file is not supported for path: {rpath}"
+ )
+ elif isinstance(pvfs_identifier, PVFSDatabaseIdentifier):
+ raise Exception(
+ f"get file is not supported for path: {rpath}"
+ )
+ elif isinstance(pvfs_identifier, PVFSTableIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ table_identifier = pvfs_identifier.get_identifier()
+ if pvfs_identifier.sub_path is None:
+ raise Exception(
+ f"get file is not supported for path: {rpath}"
+ )
+ else:
+ table = rest_api.get_table(table_identifier)
+ storage_type = self._get_storage_type(table.path)
+ storage_location = table.path
+ actual_path = pvfs_identifier.get_actual_path(storage_location)
+ fs = self._get_filesystem(pvfs_identifier, storage_type)
+ return fs.get_file(
+ self._strip_storage_protocol(storage_type, actual_path),
+ lpath,
+ **kwargs
+ )
+
+ def _rm(self, path):
+ raise Exception(
+ "_rm is not implemented for Paimon Virtual FileSystem."
+ )
+
+ def _create_object_table(self, pvfs_identifier: PVFSTableIdentifier):
+ rest_api = self.__rest_api(pvfs_identifier.catalog)
+ schema = Schema(options={'type': 'object-table'})
+ table_identifier = pvfs_identifier.get_identifier()
+ rest_api.create_table(table_identifier, schema)
+
+ @staticmethod
+ def __converse_ts_to_datatime(ts: int):
+ return datetime.datetime.fromtimestamp(ts / 1000, tz=datetime.timezone.utc)
+
+ @staticmethod
+ def _strip_storage_protocol(storage_type: StorageType, path: str):
+ if storage_type == StorageType.LOCAL:
+ return path[len(f"{StorageType.LOCAL.value}:"):]
+
+ # OSS behaves differently from S3 and GCS: if we do not strip the
+ # protocol, it always returns an empty array.
+ if storage_type == StorageType.OSS:
+ if path.startswith(f"{StorageType.OSS.value}://"):
+ return path[len(f"{StorageType.OSS.value}://"):]
+ return path
+
+ raise Exception(
+ f"Storage type:{storage_type} doesn't support now."
+ )
+
+ @staticmethod
+ def _convert_actual_info(
+ entry: Dict,
+ storage_type: StorageType,
+ storage_location: str,
+ virtual_location: str,
+ ):
+ path = PaimonVirtualFileSystem._convert_actual_path(storage_type, entry["name"], storage_location, virtual_location)
+
+ if "mtime" in entry:
+ # HDFS and GCS
+ return PaimonVirtualFileSystem._create_file_detail(path, entry["size"], entry["type"], entry["mtime"])
+ elif "LastModified" in entry:
+ # S3 and OSS
+ return PaimonVirtualFileSystem._create_file_detail(path, entry["size"], entry["type"], entry["LastModified"])
+ # Unknown
+ return PaimonVirtualFileSystem._create_file_detail(path, entry["size"], entry["type"])
+
+ @staticmethod
+ def _create_dir_detail(path: str) -> Dict[str, Any]:
+ return PaimonVirtualFileSystem._create_file_detail(path, 0, 'directory', None)
+
+ @staticmethod
+ def _create_file_detail(name: str, size: int, type: str, mtime: int = None) -> Dict[str, Any]:
+ return {
+ "name": name,
+ "size": size,
+ "type": type,
+ "mtime": mtime,
+ }
+
+ @staticmethod
+ def _convert_database_virtual_path(
+ catalog_name: str,
+ database_name: str
+ ):
+ return f'{PROTOCOL_NAME}://{catalog_name}/{database_name}'
+
+ @staticmethod
+ def _convert_table_virtual_path(
+ catalog_name: str,
+ database_name: str,
+ table_name: str
+ ):
+ return f'{PROTOCOL_NAME}://{catalog_name}/{database_name}/{table_name}'
+
+ @staticmethod
+ def _convert_actual_path(
+ storage_type: StorageType,
+ actual_path: str,
+ storage_location: str,
+ virtual_location: str,
+ ):
+ actual_path = PaimonVirtualFileSystem._get_path_without_schema(storage_type, actual_path)
+ storage_location = PaimonVirtualFileSystem._get_path_without_schema(storage_type, storage_location)
+ normalized_pvfs = virtual_location.rstrip('/')
+ sub_location = actual_path[len(storage_location):].lstrip("/")
+ if len(sub_location) == 0:
+ return normalized_pvfs
+ else:
+ return f'{normalized_pvfs}/{sub_location}'
+
+ @staticmethod
+ def _get_path_without_schema(storage_type: StorageType, path: str) -> str:
+ if storage_type == StorageType.LOCAL and path.startswith(StorageType.LOCAL.value):
+ return path[len(f"{StorageType.LOCAL.value}://"):]
+ elif storage_type == StorageType.OSS and path.startswith(StorageType.OSS.value):
+ return path[len(f"{StorageType.OSS.value}://"):]
+ return path
+
+ @staticmethod
+ def _extract_pvfs_identifier(path: str) -> Optional['PVFSIdentifier']:
+ if not isinstance(path, str):
+ raise Exception("path is not a string")
+ path_without_protocol = path
+ if path.startswith(f'{PROTOCOL_NAME}://'):
+ path_without_protocol = path[7:]
+
+ if not path_without_protocol:
+ return None
+
+ components = [component for component in path_without_protocol.rstrip('/').split('/') if component]
+
+ if len(components) == 0:
+ return None
+ elif len(components) == 1:
+ return PVFSCatalogIdentifier(components[0])
+ elif len(components) == 2:
+ return PVFSDatabaseIdentifier(catalog=components[0], database=components[1])
+ elif len(components) == 3:
+ return PVFSTableIdentifier(catalog=components[0], database=components[1], table=components[2])
+ elif len(components) > 3:
+ sub_path = '/'.join(components[3:])
+ return PVFSTableIdentifier(
+ catalog=components[0], database=components[1],
+ table=components[2], sub_path=sub_path
+ )
+ return None
+
+ def _get_filesystem(self, pvfs_table_identifier: PVFSTableIdentifier, storage_type: StorageType) -> 'FileSystem':
+ read_lock = self._cache_lock.gen_rlock()
+ try:
+ read_lock.acquire()
+ cache_value: Tuple[StorageType, AbstractFileSystem] = self._cache.get(
+ storage_type
+ )
+ if cache_value is not None:
+ return cache_value
+ finally:
+ read_lock.release()
+
+ write_lock = self._cache_lock.gen_wlock()
+ try:
+ write_lock.acquire()
+ cache_value: PaimonRealStorage = self._cache.get(pvfs_table_identifier)
+ if cache_value is not None and cache_value.need_refresh() is False:
+ return cache_value.file_system
+ if storage_type == StorageType.LOCAL:
+ fs = LocalFileSystem()
+ elif storage_type == StorageType.OSS:
+ rest_api = self.__rest_api(pvfs_table_identifier.catalog)
+ load_token_response: GetTableTokenResponse = rest_api.load_table_token(
+ Identifier.create(pvfs_table_identifier.database, pvfs_table_identifier.table))
+ fs = self._get_oss_filesystem(load_token_response.token)
+ paimon_real_storage = PaimonRealStorage(
+ token=load_token_response.token,
+ expires_at_millis=load_token_response.expires_at_millis,
+ file_system=fs
+ )
+ self._cache[pvfs_table_identifier] = paimon_real_storage
+ else:
+ raise Exception(
+ f"Storage type: `{storage_type}` doesn't support now."
+ )
+ return fs
+ finally:
+ write_lock.release()
+
+ @staticmethod
+ def _get_storage_type(path: str):
+ if path.startswith(f"{StorageType.LOCAL.value}:/"):
+ return StorageType.LOCAL
+ elif path.startswith(f"{StorageType.OSS.value}://"):
+ return StorageType.OSS
+ raise Exception(
+ f"Storage type doesn't support now. Path:{path}"
+ )
+
+ @staticmethod
+ def _get_oss_filesystem(options: Dict[str, str]) -> AbstractFileSystem:
+ access_key_id = options.get(OssOptions.OSS_ACCESS_KEY_ID)
+ if access_key_id is None:
+ raise Exception(
+ "OSS access key id is not found in the options."
+ )
+
+ access_key_secret = options.get(
+ OssOptions.OSS_ACCESS_KEY_SECRET
+ )
+ if access_key_secret is None:
+ raise Exception(
+ "OSS access key secret is not found in the options."
+ )
+ oss_endpoint_url = options.get(OssOptions.OSS_ENDPOINT)
+ if oss_endpoint_url is None:
+ raise Exception(
+ "OSS endpoint url is not found in the options."
+ )
+ token = options.get(OssOptions.OSS_SECURITY_TOKEN)
+ return importlib.import_module("ossfs").OSSFileSystem(
+ key=access_key_id,
+ secret=access_key_secret,
+ token=token,
+ endpoint=oss_endpoint_url
+ )
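A sketch of how virtual paths decompose under the parsing above (the path is a placeholder):

    from pypaimon.pvfs import PaimonVirtualFileSystem, PVFSTableIdentifier

    ident = PaimonVirtualFileSystem._extract_pvfs_identifier(
        'pvfs://my_catalog/my_db/my_table/year=2025/part-0.parquet')
    assert isinstance(ident, PVFSTableIdentifier)
    # catalog='my_catalog', database='my_db', table='my_table',
    # sub_path='year=2025/part-0.parquet'
    # get_actual_path() then appends sub_path to the table's storage location.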
diff --git a/paimon-python/pypaimon/tests/api_test.py b/paimon-python/pypaimon/tests/api_test.py
index d8b872c33f..9559296652 100644
--- a/paimon-python/pypaimon/tests/api_test.py
+++ b/paimon-python/pypaimon/tests/api_test.py
@@ -16,739 +16,20 @@
# limitations under the License.
import logging
-import re
import uuid
-from typing import Dict, List, Optional, Any, Union, Tuple
-from dataclasses import dataclass
-import threading
-from http.server import HTTPServer, BaseHTTPRequestHandler
-from urllib.parse import urlparse
import unittest
import pypaimon.api as api
-from ..api.api_response import (ConfigResponse, ListDatabasesResponse, GetDatabaseResponse,
- TableMetadata, Schema, GetTableResponse, ListTablesResponse,
- TableSchema, RESTResponse, PagedList, DataField)
+from .rest_server import RESTCatalogServer
+from ..api.api_response import (ConfigResponse, TableMetadata, TableSchema, DataField)
from ..api import RESTApi
+from ..api.auth import BearTokenAuthProvider
from ..api.rest_json import JSON
from ..api.token_loader import DLFTokenLoaderFactory, DLFToken
from ..api.typedef import Identifier
from ..api.data_types import AtomicInteger, DataTypeParser, AtomicType, ArrayType, MapType, RowType
-@dataclass
-class ErrorResponse(RESTResponse):
- """Error response"""
- RESOURCE_TYPE_DATABASE = "database"
- RESOURCE_TYPE_TABLE = "table"
- RESOURCE_TYPE_VIEW = "view"
- RESOURCE_TYPE_FUNCTION = "function"
- RESOURCE_TYPE_COLUMN = "column"
- RESOURCE_TYPE_SNAPSHOT = "snapshot"
- RESOURCE_TYPE_TAG = "tag"
- RESOURCE_TYPE_BRANCH = "branch"
- RESOURCE_TYPE_DEFINITION = "definition"
- RESOURCE_TYPE_DIALECT = "dialect"
-
- resource_type: Optional[str]
- resource_name: Optional[str]
- message: str
- code: int
-
-
-class ResourcePaths:
- """Resource path constants"""
-
- TABLES = "tables"
- VIEWS = "views"
- FUNCTIONS = "functions"
- SNAPSHOTS = "snapshots"
- ROLLBACK = "rollback"
-
- def __init__(self, prefix: str = ""):
- self.prefix = prefix.rstrip('/')
-
- def config(self) -> str:
- return "/v1/config"
-
- def databases(self) -> str:
- return f"/v1/{self.prefix}/databases"
-
- def tables(self) -> str:
- return f"{self.prefix}/tables"
-
-
-# Exception classes
-class CatalogException(Exception):
- """Base catalog exception"""
-
-
-class DatabaseNotExistException(CatalogException):
- """Database not exist exception"""
-
- def __init__(self, database: str):
- self.database = database
- super().__init__(f"Database {database} does not exist")
-
-
-class DatabaseAlreadyExistException(CatalogException):
- """Database already exist exception"""
-
- def __init__(self, database: str):
- self.database = database
- super().__init__(f"Database {database} already exists")
-
-
-class DatabaseNoPermissionException(CatalogException):
- """Database no permission exception"""
-
- def __init__(self, database: str):
- self.database = database
- super().__init__(f"No permission to access database {database}")
-
-
-class TableNotExistException(CatalogException):
- """Table not exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"Table {identifier.get_full_name()} does not exist")
-
-
-class TableAlreadyExistException(CatalogException):
- """Table already exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"Table {identifier.get_full_name()} already exists")
-
-
-class TableNoPermissionException(CatalogException):
- """Table no permission exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"No permission to access table
{identifier.get_full_name()}")
-
-
-class ViewNotExistException(CatalogException):
- """View not exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"View {identifier.get_full_name()} does not exist")
-
-
-class ViewAlreadyExistException(CatalogException):
- """View already exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"View {identifier.get_full_name()} already exists")
-
-
-class FunctionNotExistException(CatalogException):
- """Function not exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"Function {identifier.get_full_name()} does not
exist")
-
-
-class FunctionAlreadyExistException(CatalogException):
- """Function already exist exception"""
-
- def __init__(self, identifier: Identifier):
- self.identifier = identifier
- super().__init__(f"Function {identifier.get_full_name()} already exists")
-
-
-class ColumnNotExistException(CatalogException):
- """Column not exist exception"""
-
- def __init__(self, column: str):
- self.column = column
- super().__init__(f"Column {column} does not exist")
-
-
-class ColumnAlreadyExistException(CatalogException):
- """Column already exist exception"""
-
- def __init__(self, column: str):
- self.column = column
- super().__init__(f"Column {column} already exists")
-
-
-class DefinitionNotExistException(CatalogException):
- """Definition not exist exception"""
-
- def __init__(self, identifier: Identifier, name: str):
- self.identifier = identifier
- self.name = name
- super().__init__(f"Definition {name} does not exist in {identifier.get_full_name()}")
-
-
-class DefinitionAlreadyExistException(CatalogException):
- """Definition already exist exception"""
-
- def __init__(self, identifier: Identifier, name: str):
- self.identifier = identifier
- self.name = name
- super().__init__(f"Definition {name} already exists in {identifier.get_full_name()}")
-
-
-class DialectNotExistException(CatalogException):
- """Dialect not exist exception"""
-
- def __init__(self, identifier: Identifier, dialect: str):
- self.identifier = identifier
- self.dialect = dialect
- super().__init__(f"Dialect {dialect} does not exist in {identifier.get_full_name()}")
-
-
-class DialectAlreadyExistException(CatalogException):
- """Dialect already exist exception"""
-
- def __init__(self, identifier: Identifier, dialect: str):
- self.identifier = identifier
- self.dialect = dialect
- super().__init__(f"Dialect {dialect} already exists in {identifier.get_full_name()}")
-
-
-# Constants
-DEFAULT_MAX_RESULTS = 100
-AUTHORIZATION_HEADER_KEY = "Authorization"
-
-# REST API parameter constants
-DATABASE_NAME_PATTERN = "databaseNamePattern"
-TABLE_NAME_PATTERN = "tableNamePattern"
-VIEW_NAME_PATTERN = "viewNamePattern"
-FUNCTION_NAME_PATTERN = "functionNamePattern"
-PARTITION_NAME_PATTERN = "partitionNamePattern"
-MAX_RESULTS = "maxResults"
-PAGE_TOKEN = "pageToken"
-
-# Core options
-PATH = "path"
-TYPE = "type"
-WAREHOUSE = "warehouse"
-SNAPSHOT_CLEAN_EMPTY_DIRECTORIES = "snapshot.clean-empty-directories"
-
-# Table types
-FORMAT_TABLE = "FORMAT_TABLE"
-OBJECT_TABLE = "OBJECT_TABLE"
-
-
-class RESTCatalogServer:
- """Mock REST server for testing"""
-
- def __init__(self, data_path: str, auth_provider, config: ConfigResponse, warehouse: str,
- role_name: str = None, token_json: str = None):
- self.logger = logging.getLogger(__name__)
- self.warehouse = warehouse
- self.config_response = config
-
- # Initialize resource paths
- prefix = config.defaults.get("prefix")
- self.resource_paths = ResourcePaths(prefix)
- self.database_uri = self.resource_paths.databases()
-
- # Initialize storage
- self.database_store: Dict[str, GetDatabaseResponse] = {}
- self.table_metadata_store: Dict[str, TableMetadata] = {}
- self.no_permission_databases: List[str] = []
- self.no_permission_tables: List[str] = []
-
- # Initialize mock catalog (simplified)
- self.data_path = data_path
- self.auth_provider = auth_provider
- self.role_name = role_name
- self.token_json = token_json
-
- # HTTP server setup
- self.server = None
- self.server_thread = None
- self.port = 0
-
- def start(self) -> None:
- """Start the mock server"""
- handler = self._create_request_handler()
- self.server = HTTPServer(('localhost', 0), handler)
- self.port = self.server.server_port
-
- self.server_thread = threading.Thread(target=self.server.serve_forever)
- self.server_thread.daemon = True
- self.server_thread.start()
-
- self.logger.info(f"Mock REST server started on port {self.port}")
-
- def get_url(self) -> str:
- """Get server URL"""
- return f"http://localhost:{self.port}"
-
- def shutdown(self) -> None:
- """Shutdown the server"""
- if self.server:
- self.server.shutdown()
- self.server.server_close()
- if self.server_thread:
- self.server_thread.join()
-
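For reference, a minimal sketch of the lifecycle this mock server exposes (names taken from the class above; the try/finally pattern mirrors the tests further down; auth_provider is elided here):

    # Sketch only: exercises start()/get_url()/shutdown() as defined above.
    server = RESTCatalogServer(
        data_path="/tmp/test_warehouse",
        auth_provider=None,  # the tests below supply a real provider
        config=ConfigResponse(defaults={"prefix": "mock-test"}),
        warehouse="test_warehouse",
    )
    try:
        server.start()             # binds an ephemeral port on localhost
        print(server.get_url())    # e.g. http://localhost:<port>
    finally:
        server.shutdown()          # stops serve_forever() and joins the thread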
- def _create_request_handler(self):
- """Create HTTP request handler"""
- server_instance = self
-
- class RequestHandler(BaseHTTPRequestHandler):
- def do_GET(self):
- self._handle_request('GET')
-
- def do_POST(self):
- self._handle_request('POST')
-
- def do_DELETE(self):
- self._handle_request('DELETE')
-
- def _handle_request(self, method: str):
- try:
- # Parse request
- parsed_url = urlparse(self.path)
- resource_path = parsed_url.path
- parameters = self._parse_query_params(parsed_url.query)
-
- # Get request body
- content_length = int(self.headers.get('Content-Length', 0))
- data = self.rfile.read(content_length).decode('utf-8') if content_length > 0 else ""
-
- # Get headers
- headers = dict(self.headers)
-
- # Handle authentication
- auth_token = headers.get(AUTHORIZATION_HEADER_KEY.lower())
- if not self._authenticate(auth_token, resource_path, parameters, method, data):
- self._send_response(401, "Unauthorized")
- return
-
- # Route request
- response, status_code = server_instance._route_request(
- method, resource_path, parameters, data, headers
- )
-
- self._send_response(status_code, response)
-
- except Exception as e:
- server_instance.logger.error(f"Request handling error: {e}")
- self._send_response(500, str(e))
-
- def _parse_query_params(self, query: str) -> Dict[str, str]:
- """Parse query parameters"""
- if not query:
- return {}
-
- params = {}
- for pair in query.split('&'):
- if '=' in pair:
- key, value = pair.split('=', 1)
- params[key.strip()] = api.RESTUtil.decode_string(value.strip())
- return params
-
- def _authenticate(self, token: str, path: str, params: Dict[str, str],
- method: str, data: str) -> bool:
- """Authenticate request"""
- # Simplified authentication - always return True for mock
- return True
-
- def _send_response(self, status_code: int, body: str):
- """Send HTTP response"""
- self.send_response(status_code)
- self.send_header('Content-Type', 'application/json')
- self.end_headers()
- self.wfile.write(body.encode('utf-8'))
-
- def log_message(self, format, *args):
- """Override to use our logger"""
- server_instance.logger.debug(format % args)
-
- return RequestHandler
-
- def _route_request(self, method: str, resource_path: str, parameters: Dict[str, str],
- data: str, headers: Dict[str, str]) -> Tuple[str, int]:
- """Route HTTP request to appropriate handler"""
- try:
- # Config endpoint
- if resource_path.startswith(self.resource_paths.config()):
- warehouse_param = parameters.get(WAREHOUSE)
- if warehouse_param == self.warehouse:
- return self._mock_response(self.config_response, 200)
-
- # ecs role
- if resource_path == '/ram/security-credential/':
- return self._mock_response(self.role_name, 200)
-
- if resource_path == f'/ram/security-credential/{self.role_name}':
- return self._mock_response(self.token_json, 200)
-
- # Databases endpoint
- if resource_path == self.database_uri or resource_path.startswith(self.database_uri + "?"):
- return self._databases_api_handler(method, data, parameters)
-
- # Global tables endpoint
- if resource_path.startswith(self.resource_paths.tables()):
- return self._tables_handle(parameters)
-
- # Database-specific endpoints
- if resource_path.startswith(self.database_uri + "/"):
- return self._handle_database_resource(method, resource_path, parameters, data)
-
- return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
-
- except DatabaseNotExistException as e:
- response = ErrorResponse(
- ErrorResponse.RESOURCE_TYPE_DATABASE, e.database, str(e), 404
- )
- return self._mock_response(response, 404)
- except TableNotExistException as e:
- response = ErrorResponse(
- ErrorResponse.RESOURCE_TYPE_TABLE, e.identifier.get_table_name(), str(e), 404
- )
- return self._mock_response(response, 404)
- except DatabaseNoPermissionException as e:
- response = ErrorResponse(
- ErrorResponse.RESOURCE_TYPE_DATABASE, e.database, str(e), 403
- )
- return self._mock_response(response, 403)
- except TableNoPermissionException as e:
- response = ErrorResponse(
- ErrorResponse.RESOURCE_TYPE_TABLE, e.identifier.get_table_name(), str(e), 403
- )
- return self._mock_response(response, 403)
- except Exception as e:
- self.logger.error(f"Unexpected error: {e}")
- response = ErrorResponse(None, None, str(e), 500)
- return self._mock_response(response, 500)
-
- def _handle_database_resource(self, method: str, resource_path: str,
- parameters: Dict[str, str], data: str) -> Tuple[str, int]:
- """Handle database-specific resource requests"""
- # Extract database name and resource path
- path_parts = resource_path[len(self.database_uri) + 1:].split('/')
- database_name = api.RESTUtil.decode_string(path_parts[0])
-
- # Check database permissions
- if database_name in self.no_permission_databases:
- raise DatabaseNoPermissionException(database_name)
-
- if database_name not in self.database_store:
- raise DatabaseNotExistException(database_name)
-
- # Handle different resource types
- if len(path_parts) == 1:
- # Database operations
- return self._database_handle(method, data, database_name)
-
- elif len(path_parts) == 2:
- # Collection operations (tables, views, functions)
- resource_type = path_parts[1]
-
- if resource_type.startswith(ResourcePaths.TABLES):
- return self._tables_handle(method, data, database_name, parameters)
-
- elif len(path_parts) >= 3:
- # Individual resource operations
- resource_type = path_parts[1]
- resource_name = api.RESTUtil.decode_string(path_parts[2])
- identifier = Identifier.create(database_name, resource_name)
-
- if resource_type == ResourcePaths.TABLES:
- return self._handle_table_resource(method, path_parts, identifier, data, parameters)
-
- return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
-
- def _handle_table_resource(self, method: str, path_parts: List[str],
- identifier: Identifier, data: str,
- parameters: Dict[str, str]) -> Tuple[str, int]:
- """Handle table-specific resource requests"""
- # Check table permissions
- if identifier.get_full_name() in self.no_permission_tables:
- raise TableNoPermissionException(identifier)
-
- if len(path_parts) == 3:
- # Basic table operations
- return self._table_handle(method, data, identifier)
- return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
-
- def _databases_api_handler(self, method: str, data: str,
- parameters: Dict[str, str]) -> Tuple[str, int]:
- """Handle databases API requests"""
- if method == "GET":
- database_name_pattern = parameters.get(DATABASE_NAME_PATTERN)
- databases = [
- db_name for db_name in self.database_store.keys()
- if not database_name_pattern or self._match_name_pattern(db_name, database_name_pattern)
- ]
- return self._generate_final_list_databases_response(parameters, databases)
-
- return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
-
- def _database_handle(self, method: str, data: str, database_name: str) -> Tuple[str, int]:
- """Handle individual database operations"""
- if database_name not in self.database_store:
- raise DatabaseNotExistException(database_name)
-
- database = self.database_store[database_name]
-
- if method == "GET":
- response = database
- return self._mock_response(response, 200)
-
- elif method == "DELETE":
- del self.database_store[database_name]
- return self._mock_response("", 200)
- return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
-
- def _tables_handle(self, method: str = None, data: str = None, database_name: str = None,
- parameters: Dict[str, str] = None) -> Tuple[str, int]:
- """Handle tables operations"""
- if parameters is None:
- parameters = {}
-
- if database_name:
- # Database-specific tables
- if method == "GET":
- tables = self._list_tables(database_name, parameters)
- return self._generate_final_list_tables_response(parameters, tables)
- return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
-
- def _table_handle(self, method: str, data: str, identifier: Identifier) -> Tuple[str, int]:
- """Handle individual table operations"""
- if method == "GET":
- if identifier.is_system_table():
- # Handle system table
- schema = Schema(fields=[], options={PATH: f"/tmp/{identifier.get_full_name()}"})
- table_metadata = self._create_table_metadata(identifier, 1, schema, None, False)
- else:
- if identifier.get_full_name() not in self.table_metadata_store:
- raise TableNotExistException(identifier)
- table_metadata = self.table_metadata_store[identifier.get_full_name()]
-
- schema = table_metadata.schema.to_schema()
- path = schema.options.pop(PATH, None)
-
- response = self.mock_table(identifier, table_metadata, path, schema)
- return self._mock_response(response, 200)
- #
- # elif method == "POST":
- # # Alter table
- # request_body = JSON.from_json(data, AlterTableRequest)
- # self._alter_table_impl(identifier, request_body.get_changes())
- # return self._mock_response("", 200)
-
- elif method == "DELETE":
- # Drop table
- if identifier.get_full_name() in self.table_metadata_store:
- del self.table_metadata_store[identifier.get_full_name()]
- if identifier.get_full_name() in self.table_latest_snapshot_store:
- del self.table_latest_snapshot_store[identifier.get_full_name()]
- if identifier.get_full_name() in self.table_partitions_store:
- del self.table_partitions_store[identifier.get_full_name()]
-
- return self._mock_response("", 200)
-
- return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
-
- # Utility methods
- def _mock_response(self, response: Union[RESTResponse, str], http_code: int) -> Tuple[str, int]:
- """Create mock response"""
- if isinstance(response, str):
- return response, http_code
-
- try:
- return JSON.to_json(response), http_code
- except Exception as e:
- self.logger.error(f"Failed to serialize response: {e}")
- return str(e), 500
-
- def _get_max_results(self, parameters: Dict[str, str]) -> int:
- """Get max results from parameters"""
- max_results_str = parameters.get(MAX_RESULTS)
- if max_results_str:
- try:
- max_results = int(max_results_str)
- return min(max_results, DEFAULT_MAX_RESULTS) if max_results > 0 else DEFAULT_MAX_RESULTS
- except ValueError:
- raise ValueError(f"Invalid maxResults value: {max_results_str}")
- return DEFAULT_MAX_RESULTS
-
- def _build_paged_entities(self, entities: List[Any], max_results: int,
- page_token: Optional[str], desc: bool = False) -> PagedList:
- """Build paged entities"""
- # Sort entities
- sorted_entities = sorted(entities, key=self._get_paged_key, reverse=desc)
-
- # Apply pagination
- paged_entities = []
- for entity in sorted_entities:
- if len(paged_entities) < max_results:
- if not page_token or self._get_paged_key(entity) > page_token:
- paged_entities.append(entity)
- else:
- break
-
- # Determine next page token
- next_page_token = None
- if len(paged_entities) == max_results and len(sorted_entities) > max_results:
- next_page_token = self._get_paged_key(paged_entities[-1])
-
- return PagedList(elements=paged_entities, next_page_token=next_page_token)
-
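To illustrate the keyset pagination implemented by _build_paged_entities above (a sketch; the entity names are invented):

    names = ["db_a", "db_b", "db_c", "db_d"]
    page1 = server._build_paged_entities(names, max_results=2, page_token=None)
    # page1.elements == ["db_a", "db_b"]; page1.next_page_token == "db_b"
    page2 = server._build_paged_entities(names, max_results=2, page_token="db_b")
    # page2.elements == ["db_c", "db_d"]; a third call with token "db_d" returns no elements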
- def _get_paged_key(self, entity: Any) -> str:
- """Get paging key for entity"""
- if isinstance(entity, str):
- return entity
- elif hasattr(entity, 'get_name'):
- return entity.get_name()
- elif hasattr(entity, 'get_full_name'):
- return entity.get_full_name()
- elif hasattr(entity, 'name'):
- return entity.name
- else:
- return str(entity)
-
- def _match_name_pattern(self, name: str, pattern: str) -> bool:
- """Match name against SQL pattern"""
- if not pattern:
- raise ValueError("Pattern cannot be empty")
- regex_pattern = self._sql_pattern_to_regex(pattern)
- return re.match(regex_pattern, name) is not None
-
- def _sql_pattern_to_regex(self, pattern: str) -> str:
- """Convert SQL pattern to regex"""
- regex = []
- escaped = False
-
- for char in pattern:
- if escaped:
- regex.append(re.escape(char))
- escaped = False
- elif char == '\\':
- escaped = True
- elif char == '%':
- regex.append('.*')
- elif char == '_':
- regex.append('.')
- else:
- regex.append(re.escape(char))
-
- return '^' + ''.join(regex) + '$'
-
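The SQL-style matching above translates % to .* and _ to . (honoring backslash escapes); for instance (a sketch, assuming a server instance):

    server._sql_pattern_to_regex("test%")              # '^test.*$'
    server._match_name_pattern("test_db1", "test%")    # True
    server._match_name_pattern("test_db1", "test_db_") # True ('_' matches any single char)
    server._match_name_pattern("prod_db", "test%")     # False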
- def _create_table_metadata(self, identifier: Identifier, schema_id: int,
- schema: Schema, uuid_str: str, is_external: bool) -> TableMetadata:
- """Create table metadata"""
- options = schema.options.copy()
- path = f"/tmp/{identifier.get_full_name()}"
- options[PATH] = path
-
- table_schema = TableSchema(
- id=schema_id,
- fields=schema.fields,
- highest_field_id=len(schema.fields) - 1,
- partition_keys=schema.partition_keys,
- primary_keys=schema.primary_keys,
- options=options,
- comment=schema.comment
- )
-
- return TableMetadata(
- schema=table_schema,
- is_external=is_external,
- uuid=uuid_str or str(uuid.uuid4())
- )
-
- # List methods
- def _list_tables(self, database_name: str, parameters: Dict[str, str]) -> List[str]:
- """List tables in database"""
- table_name_pattern = parameters.get(TABLE_NAME_PATTERN)
- tables = []
-
- for full_name, metadata in self.table_metadata_store.items():
- identifier = Identifier.from_string(full_name)
- if (identifier.get_database_name() == database_name and
- (not table_name_pattern or self._match_name_pattern(identifier.get_table_name(),
- table_name_pattern))):
- tables.append(identifier.get_table_name())
-
- return tables
-
- # Response generation methods
- def _generate_final_list_databases_response(self, parameters: Dict[str, str],
- databases: List[str]) -> Tuple[str, int]:
- """Generate final list databases response"""
- if databases:
- max_results = self._get_max_results(parameters)
- page_token = parameters.get(PAGE_TOKEN)
- paged_dbs = self._build_paged_entities(databases, max_results, page_token)
- response = ListDatabasesResponse(
- databases=paged_dbs.elements,
- next_page_token=paged_dbs.next_page_token
- )
- else:
- response = ListDatabasesResponse(databases=[], next_page_token=None)
-
- return self._mock_response(response, 200)
-
- def _generate_final_list_tables_response(self, parameters: Dict[str, str],
- tables: List[str]) -> Tuple[str, int]:
- """Generate final list tables response"""
- if tables:
- max_results = self._get_max_results(parameters)
- page_token = parameters.get(PAGE_TOKEN)
- paged_tables = self._build_paged_entities(tables, max_results, page_token)
- response = ListTablesResponse(
- tables=paged_tables.elements,
- next_page_token=paged_tables.next_page_token
- )
- else:
- response = ListTablesResponse(tables=[], next_page_token=None)
-
- return self._mock_response(response, 200)
-
- def add_no_permission_database(self, database: str) -> None:
- """Add no permission database"""
- self.no_permission_databases.append(database)
-
- def add_no_permission_table(self, identifier: Identifier) -> None:
- """Add no permission table"""
- self.no_permission_tables.append(identifier.get_full_name())
-
- def mock_database(self, name: str, options: dict[str, str]) -> GetDatabaseResponse:
- return GetDatabaseResponse(
- id=str(uuid.uuid4()),
- name=name,
- location=f"{self.data_path}/{name}",
- options=options,
- owner="owner",
- created_at=1,
- created_by="created",
- updated_at=1,
- updated_by="updated"
- )
-
- def mock_table(self, identifier: Identifier, table_metadata: TableMetadata, path: str,
- schema: Schema) -> GetTableResponse:
- return GetTableResponse(
- id=str(table_metadata.uuid),
- name=identifier.get_object_name(),
- path=path,
- is_external=table_metadata.is_external,
- schema_id=table_metadata.schema.id,
- schema=schema,
- owner="owner",
- created_at=1,
- created_by="created",
- updated_at=1,
- updated_by="updated"
- )
-
-
class ApiTestCase(unittest.TestCase):
def test_parse_data(self):
@@ -845,16 +126,11 @@ class ApiTestCase(unittest.TestCase):
# Create config
config = ConfigResponse(defaults={"prefix": "mock-test"})
-
- # Create mock auth provider
- class MockAuthProvider:
- def merge_auth_header(self, headers, auth_param):
- return {AUTHORIZATION_HEADER_KEY: "Bearer test-token"}
-
+ token = str(uuid.uuid4())
# Create server
server = RESTCatalogServer(
data_path="/tmp/test_warehouse",
- auth_provider=MockAuthProvider(),
+ auth_provider=BearTokenAuthProvider(token),
config=config,
warehouse="test_warehouse"
)
@@ -885,9 +161,8 @@ class ApiTestCase(unittest.TestCase):
'uri': f"http://localhost:{server.port}",
'warehouse': 'test_warehouse',
'dlf.region': 'cn-hangzhou',
- "token.provider": "xxxx",
- 'dlf.access-key-id': 'xxxx',
- 'dlf.access-key-secret': 'xxxx'
+ "token.provider": "bear",
+ 'token': token
}
api = RESTApi(options)
self.assertSetEqual(set(api.list_databases()), {*test_databases})
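The test now drives RESTApi with the new bearer-token provider rather than DLF access keys. The actual BearTokenAuthProvider lives in pypaimon/api/auth.py; presumably it amounts to attaching the token as a standard bearer Authorization header, along these lines (a sketch, not the verbatim class):

    class BearerAuthSketch:
        def __init__(self, token: str):
            self.token = token

        def merge_auth_header(self, headers: dict, auth_param) -> dict:
            # Return headers with an RFC 6750-style bearer credential attached.
            return {**headers, "Authorization": f"Bearer {self.token}"}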
diff --git a/paimon-python/pypaimon/tests/pvfs_test.py b/paimon-python/pypaimon/tests/pvfs_test.py
new file mode 100644
index 0000000000..c3248dcea7
--- /dev/null
+++ b/paimon-python/pypaimon/tests/pvfs_test.py
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+import uuid
+import pandas
+from pathlib import Path
+
+from pypaimon.api import ConfigResponse
+from pypaimon.api.api_response import TableSchema, TableMetadata
+from pypaimon.api.auth import BearTokenAuthProvider
+from pypaimon.api.data_types import DataField, AtomicType
+from pypaimon.pvfs import PaimonVirtualFileSystem
+from pypaimon.tests.api_test import RESTCatalogServer
+
+
+class PVFSTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.temp_dir = tempfile.mkdtemp(prefix="unittest_")
+ self.temp_path = Path(self.temp_dir)
+ # Create config
+ config = ConfigResponse(defaults={"prefix": "mock-test"})
+
+ # Create server
+ self.data_path = self.temp_dir
+ self.catalog = 'test_warehouse'
+ self.token = str(uuid.uuid4())
+ # Create server
+ self.server = RESTCatalogServer(
+ data_path=self.data_path,
+ auth_provider=BearTokenAuthProvider(self.token),
+ config=config,
+ warehouse=self.catalog)
+ self.server.start()
+ print(f"Server started at: {self.server.get_url()}")
+ print(f"create: {self.temp_path}")
+ options = {
+ 'uri': f"http://localhost:{self.server.port}",
+ 'warehouse': 'test_warehouse',
+ 'dlf.region': 'cn-hangzhou',
+ "token.provider": "bear",
+ 'token': self.token
+ }
+ self.pvfs = PaimonVirtualFileSystem(options)
+ self.database = 'test_database'
+ self.table = 'test_table'
+ self.test_databases = {
+ self.database: self.server.mock_database(self.database, {"k1": "v1", "k2": "v2"}),
+ }
+ data_fields = [
+ DataField(0, "id", AtomicType('INT'), 'id'),
+ DataField(1, "name", AtomicType('STRING'), 'name')
+ ]
+ schema = TableSchema(len(data_fields), data_fields, len(data_fields), [], [], {}, "")
+ self.server.database_store.update(self.test_databases)
+ self.test_tables = {
+ f"{self.database}.{self.table}": TableMetadata(uuid=str(uuid.uuid4()), is_external=True, schema=schema),
+ }
+ self.server.table_metadata_store.update(self.test_tables)
+
+ def tearDown(self):
+ if self.temp_path.exists():
+ shutil.rmtree(self.temp_path)
+ print(f"clean: {self.temp_path}")
+ if self.server is not None:
+ self.server.shutdown()
+ print("Server stopped")
+
+ def _create_parquet_file(self, database: str, table: str, data_file_name: str):
+ fs = self.pvfs
+ path = f'pvfs://{self.catalog}/{database}/{table}/{data_file_name}'
+ fs.mkdir(f'pvfs://{self.catalog}/{database}/{table}')
+ print(fs.ls(f'pvfs://{self.catalog}/{database}/{table}'))
+ fs.touch(path)
+ print(fs.ls(path))
+ self.assertEqual(fs.exists(f'pvfs://{self.catalog}/{database}/{table}'), True)
+ self.assertEqual(fs.exists(path), True)
+ data = {
+ 'id': [1, 2, 3, 4, 5],
+ 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'],
+ }
+
+ df = pandas.DataFrame(data)
+
+ df.to_parquet(
+ f'{self.data_path}/{self.catalog}/{database}/{table}/{data_file_name}',
+ engine='pyarrow', index=False
+ )
+
+ def test_arrow(self):
+ import pyarrow.parquet as pq
+ fs = self.pvfs
+ database = 'arrow_db'
+ table = 'test_table'
+ data_file_name = 'a.parquet'
+ self._create_parquet_file(database, table, data_file_name)
+ path = f'pvfs://{self.catalog}/{database}/{table}/{data_file_name}'
+ dataset = pq.ParquetDataset(path, filesystem=fs)
+ table = dataset.read()
+ first_row = table.slice(0, 1).to_pydict()
+ print(f"first_row: {first_row}")
+ df = table.to_pandas()
+ self.assertEqual(len(df), 5)
+
+ def test_ray(self):
+ import ray
+ if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True)
+ fs = self.pvfs
+ database = 'ray_db'
+ table = 'test_table'
+ data_file_name = 'a.parquet'
+ self._create_parquet_file(database, table, data_file_name)
+ path = f'pvfs://{self.catalog}/{database}/{table}/{data_file_name}'
+ ds = ray.data.read_parquet(filesystem=fs, paths=path)
+ print(ds.count())
+ self.assertEqual(ds.count(), 5)
+
+ def test_api(self):
+ nested_dir = self.temp_path / self.database / self.table
+ nested_dir.mkdir(parents=True)
+ data_file_name = 'a.parquet'
+ self._create_parquet_file(self.database, self.table, data_file_name)
+ database_dirs = self.pvfs.ls(f"pvfs://{self.catalog}", detail=False)
+ expect_database_dirs = set(map(
+ lambda x: self.pvfs._convert_database_virtual_path(self.catalog, x),
+ list(self.test_databases.keys())
+ ))
+ self.assertSetEqual(set(database_dirs), expect_database_dirs)
+ table_dirs = self.pvfs.ls(f"pvfs://{self.catalog}/{self.database}", detail=False)
+ expect_table_dirs = set(map(
+ lambda x: self.pvfs._convert_table_virtual_path(self.catalog, self.database, x),
+ [self.table]
+ ))
+ self.assertSetEqual(set(table_dirs), expect_table_dirs)
+ database_virtual_path = f"pvfs://{self.catalog}/{self.database}"
+ self.assertEqual(database_virtual_path, self.pvfs.info(database_virtual_path).get('name'))
+ self.assertEqual(True, self.pvfs.exists(database_virtual_path))
+ table_virtual_path = f"pvfs://{self.catalog}/{self.database}/{self.table}"
+ self.assertEqual(table_virtual_path, self.pvfs.info(table_virtual_path).get('name'))
+ self.assertEqual(True, self.pvfs.exists(table_virtual_path))
+ user_dirs = self.pvfs.ls(f"pvfs://{self.catalog}/{self.database}/{self.table}", detail=False)
+ self.assertSetEqual(set(user_dirs), {f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'})
+
+ data_file_name = 'data.txt'
+ data_file_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'
+ self.pvfs.touch(data_file_path)
+ content = 'Hello World'
+ date_file_virtual_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'
+ data_file_name = 'data_2.txt'
+ date_file_new_virtual_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'
+ self.pvfs.cp(date_file_virtual_path, date_file_new_virtual_path)
+ self.assertEqual(True, self.pvfs.exists(date_file_virtual_path))
+ self.assertEqual(True, self.pvfs.exists(date_file_new_virtual_path))
+
+ data_file_mv_virtual_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/mv.txt'
+ self.pvfs.mv(date_file_virtual_path, data_file_mv_virtual_path)
+ self.assertEqual(False, self.pvfs.exists(date_file_virtual_path))
+ self.assertEqual(True, self.pvfs.exists(data_file_mv_virtual_path))
+
+ mv_source_table_path = f'pvfs://{self.catalog}/{self.database}/mv_table1'
+ mv_des_table_path = f'pvfs://{self.catalog}/{self.database}/des_table1'
+ self.pvfs.mkdir(mv_source_table_path)
+ self.assertTrue(self.pvfs.exists(mv_source_table_path))
+ self.assertFalse(self.pvfs.exists(mv_des_table_path))
+ self.pvfs.mv(mv_source_table_path, mv_des_table_path)
+ self.assertTrue(self.pvfs.exists(mv_des_table_path))
+
+ with self.pvfs.open(date_file_new_virtual_path, 'w') as w:
+ w.write(content)
+
+ with self.pvfs.open(date_file_new_virtual_path, 'r', encoding='utf-8') as file:
+ lines = file.readlines()
+ self.assertListEqual([content], lines)
+
+ database_new_virtual_path = f"pvfs://{self.catalog}/new_db"
+ self.assertEqual(False, self.pvfs.exists(database_new_virtual_path))
+ self.pvfs.mkdir(database_new_virtual_path)
+ self.assertEqual(True, self.pvfs.exists(database_new_virtual_path))
+
+ table_data_new_virtual_path = f"pvfs://{self.catalog}/{self.database}/new_table/data.txt"
+ self.assertEqual(False, self.pvfs.exists(table_data_new_virtual_path))
+ self.pvfs.mkdir(table_data_new_virtual_path)
+ self.assertEqual(True, self.pvfs.exists(table_data_new_virtual_path))
+ self.pvfs.makedirs(table_data_new_virtual_path)
+ self.assertEqual(True, self.pvfs.exists(table_data_new_virtual_path))
+ self.assertTrue(self.pvfs.created(table_data_new_virtual_path) is not None)
+ self.assertTrue(self.pvfs.modified(table_data_new_virtual_path) is not None)
+ self.assertEqual('Hello World', self.pvfs.cat_file(date_file_new_virtual_path).decode('utf-8'))
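Taken together, these tests pin down the fsspec contract of PaimonVirtualFileSystem: a virtual path of the form pvfs://{catalog}/{database}/{table}/{file} is resolved through the REST catalog to the table's physical location, after which ordinary filesystem calls apply. A condensed usage sketch (options as in setUp above; the endpoint and token values are placeholders):

    import pyarrow.parquet as pq
    from pypaimon.pvfs import PaimonVirtualFileSystem

    options = {
        'uri': 'http://localhost:1234',   # REST catalog endpoint (placeholder)
        'warehouse': 'test_warehouse',
        'token.provider': 'bear',
        'token': '<token>',               # placeholder
    }
    fs = PaimonVirtualFileSystem(options)
    fs.mkdir('pvfs://test_warehouse/db/tbl')            # registers the table dir
    fs.touch('pvfs://test_warehouse/db/tbl/a.parquet')  # creates an empty file
    tbl = pq.ParquetDataset('pvfs://test_warehouse/db/tbl/a.parquet',
                            filesystem=fs).read()       # reads through the vfs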
diff --git a/paimon-python/pypaimon/tests/api_test.py b/paimon-python/pypaimon/tests/rest_server.py
similarity index 66%
copy from paimon-python/pypaimon/tests/api_test.py
copy to paimon-python/pypaimon/tests/rest_server.py
index d8b872c33f..5439402c6a 100644
--- a/paimon-python/pypaimon/tests/api_test.py
+++ b/paimon-python/pypaimon/tests/rest_server.py
@@ -1,39 +1,39 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
import logging
import re
+import time
import uuid
+from pathlib import Path
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse
-import unittest
import pypaimon.api as api
+from ..api import RenameTableRequest, CreateTableRequest, CreateDatabaseRequest
from ..api.api_response import (ConfigResponse, ListDatabasesResponse, GetDatabaseResponse,
TableMetadata, Schema, GetTableResponse, ListTablesResponse,
- TableSchema, RESTResponse, PagedList, DataField)
-from ..api import RESTApi
+ TableSchema, RESTResponse, PagedList)
from ..api.rest_json import JSON
-from ..api.token_loader import DLFTokenLoaderFactory, DLFToken
from ..api.typedef import Identifier
-from ..api.data_types import AtomicInteger, DataTypeParser, AtomicType, ArrayType, MapType, RowType
@dataclass
@@ -56,28 +56,6 @@ class ErrorResponse(RESTResponse):
code: int
-class ResourcePaths:
- """Resource path constants"""
-
- TABLES = "tables"
- VIEWS = "views"
- FUNCTIONS = "functions"
- SNAPSHOTS = "snapshots"
- ROLLBACK = "rollback"
-
- def __init__(self, prefix: str = ""):
- self.prefix = prefix.rstrip('/')
-
- def config(self) -> str:
- return "/v1/config"
-
- def databases(self) -> str:
- return f"/v1/{self.prefix}/databases"
-
- def tables(self) -> str:
- return f"{self.prefix}/tables"
-
-
# Exception classes
class CatalogException(Exception):
"""Base catalog exception"""
@@ -250,7 +228,7 @@ class RESTCatalogServer:
# Initialize resource paths
prefix = config.defaults.get("prefix")
- self.resource_paths = ResourcePaths(prefix)
+ self.resource_paths = api.ResourcePaths(prefix=prefix)
self.database_uri = self.resource_paths.databases()
# Initialize storage
@@ -374,6 +352,7 @@ class RESTCatalogServer:
data: str, headers: Dict[str, str]) -> Tuple[str, int]:
"""Route HTTP request to appropriate handler"""
try:
# Config endpoint
if resource_path.startswith(self.resource_paths.config()):
warehouse_param = parameters.get(WAREHOUSE)
@@ -388,16 +367,62 @@ class RESTCatalogServer:
return self._mock_response(self.token_json, 200)
# Databases endpoint
- if resource_path == self.database_uri or resource_path.startswith(self.database_uri + "?"):
+ if resource_path == self.resource_paths.databases() or resource_path.startswith(self.database_uri + "?"):
return self._databases_api_handler(method, data, parameters)
- # Global tables endpoint
- if resource_path.startswith(self.resource_paths.tables()):
- return self._tables_handle(parameters)
-
+ if resource_path == self.resource_paths.rename_table():
+ rename_request = JSON.from_json(data, RenameTableRequest)
+ source_table = rename_request.source
+ destination_table = rename_request.destination
+ source = self.table_metadata_store.get(source_table.get_full_name())
+ self.table_metadata_store.update({destination_table.get_full_name(): source})
+ source_table_dir = (Path(self.data_path) / self.warehouse
+ / source_table.database_name / source_table.object_name)
+ destination_table_dir = (Path(self.data_path) / self.warehouse
+ / destination_table.database_name / destination_table.object_name)
+ if not source_table_dir.exists():
+ destination_table_dir.mkdir(parents=True)
+ else:
+ source_table_dir.rename(destination_table_dir)
+ return self._mock_response("", 200)
+
+ database = resource_path.split("/")[4]
# Database-specific endpoints
- if resource_path.startswith(self.database_uri + "/"):
- return self._handle_database_resource(method, resource_path, parameters, data)
+ if resource_path.startswith(self.resource_paths.database(database)):
+ """Handle database-specific resource requests"""
+ # Extract database name and resource path
+ path_parts = resource_path[len(self.database_uri) + 1:].split('/')
+ database_name = api.RESTUtil.decode_string(path_parts[0])
+
+ # Check database permissions
+ if database_name in self.no_permission_databases:
+ raise DatabaseNoPermissionException(database_name)
+
+ if database_name not in self.database_store:
+ raise DatabaseNotExistException(database_name)
+
+ # Handle different resource types
+ if len(path_parts) == 1:
+ # Database operations
+ return self._database_handle(method, data, database_name)
+
+ elif len(path_parts) == 2:
+ # Collection operations (tables, views, functions)
+ resource_type = path_parts[1]
+
+ if resource_type.startswith(api.ResourcePaths.TABLES):
+ return self._tables_handle(method, data, database_name, parameters)
+
+ elif len(path_parts) >= 3:
+ # Individual resource operations
+ resource_type = path_parts[1]
+ resource_name = api.RESTUtil.decode_string(path_parts[2])
+ identifier = Identifier.create(database_name, resource_name)
+
+ if resource_type == api.ResourcePaths.TABLES:
+ return self._handle_table_resource(method, path_parts, identifier, data, parameters)
+
+ return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
@@ -426,43 +451,6 @@ class RESTCatalogServer:
response = ErrorResponse(None, None, str(e), 500)
return self._mock_response(response, 500)
- def _handle_database_resource(self, method: str, resource_path: str,
- parameters: Dict[str, str], data: str) -> Tuple[str, int]:
- """Handle database-specific resource requests"""
- # Extract database name and resource path
- path_parts = resource_path[len(self.database_uri) + 1:].split('/')
- database_name = api.RESTUtil.decode_string(path_parts[0])
-
- # Check database permissions
- if database_name in self.no_permission_databases:
- raise DatabaseNoPermissionException(database_name)
-
- if database_name not in self.database_store:
- raise DatabaseNotExistException(database_name)
-
- # Handle different resource types
- if len(path_parts) == 1:
- # Database operations
- return self._database_handle(method, data, database_name)
-
- elif len(path_parts) == 2:
- # Collection operations (tables, views, functions)
- resource_type = path_parts[1]
-
- if resource_type.startswith(ResourcePaths.TABLES):
- return self._tables_handle(method, data, database_name, parameters)
-
- elif len(path_parts) >= 3:
- # Individual resource operations
- resource_type = path_parts[1]
- resource_name = api.RESTUtil.decode_string(path_parts[2])
- identifier = Identifier.create(database_name, resource_name)
-
- if resource_type == ResourcePaths.TABLES:
- return self._handle_table_resource(method, path_parts, identifier, data, parameters)
-
- return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
-
def _handle_table_resource(self, method: str, path_parts: List[str],
identifier: Identifier, data: str,
parameters: Dict[str, str]) -> Tuple[str, int]:
@@ -486,7 +474,12 @@ class RESTCatalogServer:
if not database_name_pattern or self._match_name_pattern(db_name, database_name_pattern)
]
return self._generate_final_list_databases_response(parameters, databases)
-
+ if method == "POST":
+ create_database = JSON.from_json(data, CreateDatabaseRequest)
+ self.database_store.update({
+ create_database.name: self.mock_database(create_database.name, create_database.options)
+ })
+ return self._mock_response("", 200)
return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
def _database_handle(self, method: str, data: str, database_name: str) -> Tuple[str, int]:
@@ -516,24 +509,27 @@ class RESTCatalogServer:
if method == "GET":
tables = self._list_tables(database_name, parameters)
return self._generate_final_list_tables_response(parameters, tables)
+ elif method == "POST":
+ create_table = JSON.from_json(data, CreateTableRequest)
+ table_metadata = self._create_table_metadata(
+ create_table.identifier, 1, create_table.schema, str(uuid.uuid4()), False
+ )
+ self.table_metadata_store.update({create_table.identifier.get_full_name(): table_metadata})
+ table_dir = Path(self.data_path) / self.warehouse / database_name / create_table.identifier.object_name
+ if not table_dir.exists():
+ table_dir.mkdir(parents=True)
+ return self._mock_response("", 200)
return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
def _table_handle(self, method: str, data: str, identifier: Identifier) -> Tuple[str, int]:
"""Handle individual table operations"""
if method == "GET":
- if identifier.is_system_table():
- # Handle system table
- schema = Schema(fields=[], options={PATH: f"/tmp/{identifier.get_full_name()}"})
- table_metadata = self._create_table_metadata(identifier, 1, schema, None, False)
- else:
- if identifier.get_full_name() not in self.table_metadata_store:
- raise TableNotExistException(identifier)
- table_metadata = self.table_metadata_store[identifier.get_full_name()]
-
+ if identifier.get_full_name() not in self.table_metadata_store:
+ raise TableNotExistException(identifier)
+ table_metadata = self.table_metadata_store[identifier.get_full_name()]
+ table_path = f'file://{self.data_path}/{self.warehouse}/{identifier.database_name}/{identifier.object_name}'
schema = table_metadata.schema.to_schema()
- path = schema.options.pop(PATH, None)
-
- response = self.mock_table(identifier, table_metadata, path, schema)
+ response = self.mock_table(identifier, table_metadata, table_path, schema)
return self._mock_response(response, 200)
#
# elif method == "POST":
@@ -644,9 +640,6 @@ class RESTCatalogServer:
schema: Schema, uuid_str: str, is_external: bool) -> TableMetadata:
"""Create table metadata"""
options = schema.options.copy()
- path = f"/tmp/{identifier.get_full_name()}"
- options[PATH] = path
-
table_schema = TableSchema(
id=schema_id,
fields=schema.fields,
@@ -726,9 +719,9 @@ class RESTCatalogServer:
location=f"{self.data_path}/{name}",
options=options,
owner="owner",
- created_at=1,
+ created_at=int(time.time()) * 1000,
created_by="created",
- updated_at=1,
+ updated_at=int(time.time()) * 1000,
updated_by="updated"
)
@@ -747,203 +740,3 @@ class RESTCatalogServer:
updated_at=1,
updated_by="updated"
)
-
-
-class ApiTestCase(unittest.TestCase):
-
- def test_parse_data(self):
- simple_type_test_cases = [
- "DECIMAL",
- "DECIMAL(5)",
- "DECIMAL(10, 2)",
- "DECIMAL(38, 18)",
- "VARBINARY",
- "VARBINARY(100)",
- "VARBINARY(1024)",
- "BYTES",
- "VARCHAR(255)",
- "CHAR(10)",
- "INT",
- "BOOLEAN"
- ]
- for type_str in simple_type_test_cases:
- data_type = DataTypeParser.parse_data_type(type_str)
- self.assertEqual(data_type.nullable, True)
- self.assertEqual(data_type.type, type_str)
- field_id = AtomicInteger(0)
- simple_type = DataTypeParser.parse_data_type("VARCHAR(32)")
- self.assertEqual(simple_type.nullable, True)
- self.assertEqual(simple_type.type, 'VARCHAR(32)')
-
- array_json = {
- "type": "ARRAY",
- "element": "INT"
- }
- array_type = DataTypeParser.parse_data_type(array_json, field_id)
- self.assertEqual(array_type.element.type, 'INT')
-
- map_json = {
- "type": "MAP",
- "key": "STRING",
- "value": "INT"
- }
- map_type = DataTypeParser.parse_data_type(map_json, field_id)
- self.assertEqual(map_type.key.type, 'STRING')
- self.assertEqual(map_type.value.type, 'INT')
- row_json = {
- "type": "ROW",
- "fields": [
- {
- "name": "id",
- "type": "BIGINT",
- "description": "Primary key"
- },
- {
- "name": "name",
- "type": "VARCHAR(100)",
- "description": "User name"
- },
- {
- "name": "scores",
- "type": {
- "type": "ARRAY",
- "element": "DOUBLE"
- }
- }
- ]
- }
-
- row_type: RowType = DataTypeParser.parse_data_type(row_json, AtomicInteger(0))
- self.assertEqual(row_type.fields[0].type.type, 'BIGINT')
- self.assertEqual(row_type.fields[1].type.type, 'VARCHAR(100)')
-
- complex_json = {
- "type": "ARRAY",
- "element": {
- "type": "MAP",
- "key": "STRING",
- "value": {
- "type": "ROW",
- "fields": [
- {"name": "count", "type": "BIGINT"},
- {"name": "percentage", "type": "DOUBLE"}
- ]
- }
- }
- }
-
- complex_type: ArrayType = DataTypeParser.parse_data_type(complex_json, field_id)
- element_type: MapType = complex_type.element
- value_type: RowType = element_type.value
- self.assertEqual(value_type.fields[0].type.type, 'BIGINT')
- self.assertEqual(value_type.fields[1].type.type, 'DOUBLE')
-
- def test_api(self):
- """Example usage of RESTCatalogServer"""
- # Setup logging
- logging.basicConfig(level=logging.INFO)
-
- # Create config
- config = ConfigResponse(defaults={"prefix": "mock-test"})
-
- # Create mock auth provider
- class MockAuthProvider:
- def merge_auth_header(self, headers, auth_param):
- return {AUTHORIZATION_HEADER_KEY: "Bearer test-token"}
-
- # Create server
- server = RESTCatalogServer(
- data_path="/tmp/test_warehouse",
- auth_provider=MockAuthProvider(),
- config=config,
- warehouse="test_warehouse"
- )
- try:
- # Start server
- server.start()
- print(f"Server started at: {server.get_url()}")
- test_databases = {
- "default": server.mock_database("default", {"env": "test"}),
- "test_db1": server.mock_database("test_db1", {"env": "test"}),
- "test_db2": server.mock_database("test_db2", {"env": "test"}),
- "prod_db": server.mock_database("prod_db", {"env": "prod"})
- }
- data_fields = [
- DataField(0, "name", AtomicType('INT'), 'desc name'),
- DataField(1, "arr11", ArrayType(True, AtomicType('INT')), 'desc arr11'),
- DataField(2, "map11", MapType(False, AtomicType('INT'),
- MapType(False, AtomicType('INT'), AtomicType('INT'))),
- 'desc arr11'),
- ]
- schema = TableSchema(len(data_fields), data_fields, len(data_fields), [], [], {}, "")
- test_tables = {
- "default.user": TableMetadata(uuid=str(uuid.uuid4()), is_external=True, schema=schema),
- }
- server.table_metadata_store.update(test_tables)
- server.database_store.update(test_databases)
- options = {
- 'uri': f"http://localhost:{server.port}",
- 'warehouse': 'test_warehouse',
- 'dlf.region': 'cn-hangzhou',
- "token.provider": "xxxx",
- 'dlf.access-key-id': 'xxxx',
- 'dlf.access-key-secret': 'xxxx'
- }
- api = RESTApi(options)
- self.assertSetEqual(set(api.list_databases()), {*test_databases})
- self.assertEqual(api.get_database('default'), test_databases.get('default'))
- table = api.get_table(Identifier.from_string('default.user'))
- self.assertEqual(table.id, str(test_tables['default.user'].uuid))
-
- finally:
- # Shutdown server
- server.shutdown()
- print("Server stopped")
-
- def test_ecs_loader_token(self):
- token = DLFToken(
- access_key_id='AccessKeyId',
- access_key_secret='AccessKeySecret',
- security_token='AQoDYXdzEJr...<remainder of security token>',
- expiration="2023-12-01T12:00:00Z"
- )
- token_json = JSON.to_json(token)
- role_name = 'test_role'
- config = ConfigResponse(defaults={"prefix": "mock-test"})
- server = RESTCatalogServer(
- data_path="/tmp/test_warehouse",
- auth_provider=None,
- config=config,
- warehouse="test_warehouse",
- role_name=role_name,
- token_json=token_json
- )
- try:
- # Start server
- server.start()
- ecs_metadata_url = f"http://localhost:{server.port}/ram/security-credential/"
- options = {
- api.RESTCatalogOptions.DLF_TOKEN_LOADER: 'ecs',
- api.RESTCatalogOptions.DLF_TOKEN_ECS_METADATA_URL: ecs_metadata_url
- }
- loader = DLFTokenLoaderFactory.create_token_loader(options)
- load_token = loader.load_token()
- self.assertEqual(load_token.access_key_id, token.access_key_id)
- self.assertEqual(load_token.access_key_secret, token.access_key_secret)
- self.assertEqual(load_token.security_token, token.security_token)
- self.assertEqual(load_token.expiration, token.expiration)
- options_with_role = {
- api.RESTCatalogOptions.DLF_TOKEN_LOADER: 'ecs',
- api.RESTCatalogOptions.DLF_TOKEN_ECS_METADATA_URL: ecs_metadata_url,
- api.RESTCatalogOptions.DLF_TOKEN_ECS_ROLE_NAME: role_name,
- }
- loader = DLFTokenLoaderFactory.create_token_loader(options_with_role)
- token = loader.load_token()
- self.assertEqual(load_token.access_key_id, token.access_key_id)
- self.assertEqual(load_token.access_key_secret, token.access_key_secret)
- self.assertEqual(load_token.security_token, token.security_token)
- self.assertEqual(load_token.expiration, token.expiration)
- finally:
- # Shutdown server
- server.shutdown()
- print("Server stopped")
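With ApiTestCase dropped from this copy, rest_server.py is left as a reusable test fixture; other test modules can stand it up in a few lines, roughly as pvfs_test.py does (a sketch; the import path assumes the new module name introduced here, and the token value is arbitrary):

    from pypaimon.api import ConfigResponse
    from pypaimon.api.auth import BearTokenAuthProvider
    from pypaimon.tests.rest_server import RESTCatalogServer

    server = RESTCatalogServer(
        data_path="/tmp/wh",
        auth_provider=BearTokenAuthProvider("secret-token"),
        config=ConfigResponse(defaults={"prefix": "mock-test"}),
        warehouse="wh",
    )
    server.start()
    # ... point a RESTApi or PaimonVirtualFileSystem at server.get_url() ...
    server.shutdown()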
diff --git a/paimon-python/setup.py b/paimon-python/setup.py
index 3fbc6108dc..0d82a25ffb 100644
--- a/paimon-python/setup.py
+++ b/paimon-python/setup.py
@@ -21,7 +21,12 @@ VERSION = "0.3.dev" # noqa
PACKAGES = find_packages(include=["pypaimon*"])
-install_requires = []
+install_requires = [
+ 'readerwriterlock==1.0.9',
+ 'fsspec==2024.3.1',
+ 'cachetools==5.3.3',
+ 'ossfs==2023.12.0'
+]
long_description = "See Apache Paimon Python API \
[Doc](https://paimon.apache.org/docs/master/program-api/python-api/) for usage."