This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new da8e246a19 [python] Support read paimon table as pytorch dataset (#6987)
da8e246a19 is described below
commit da8e246a197d055c273696bd486e5f167621ee5f
Author: umi <[email protected]>
AuthorDate: Fri Jan 9 19:19:01 2026 +0800
[python] Support read paimon table as pytorch dataset (#6987)
---
.github/workflows/paimon-python-checks.yml | 82 +--
docs/content/program-api/python-api.md | 67 ++-
paimon-python/dev/lint-python.sh | 31 +-
paimon-python/dev/requirements.txt | 20 +-
.../read/{ray_datasource.py => datasource.py} | 138 ++++-
paimon-python/pypaimon/read/table_read.py | 15 +-
paimon-python/pypaimon/tests/blob_table_test.py | 2 +-
.../pypaimon/tests/reader_append_only_test.py | 78 +--
.../pypaimon/tests/reader_primary_key_test.py | 2 +-
paimon-python/pypaimon/tests/torch_read_test.py | 635 +++++++++++++++++++++
10 files changed, 953 insertions(+), 117 deletions(-)
diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml
index 4fb7fe07e4..ff2929c0fa 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -46,7 +46,7 @@ jobs:
container: "python:${{ matrix.python-version }}-slim"
strategy:
matrix:
- python-version: ['3.6.15', '3.10']
+ python-version: [ '3.6.15', '3.10' ]
steps:
- name: Checkout code
@@ -70,6 +70,7 @@ jobs:
build-essential \
git \
curl \
+ && apt-get clean \
&& rm -rf /var/lib/apt/lists/*
- name: Verify Java and Maven installation
@@ -88,21 +89,24 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
+ df -h
if [[ "${{ matrix.python-version }}" == "3.6.15" ]]; then
python -m pip install --upgrade pip==21.3.1
python --version
- python -m pip install -q pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null
+ python -m pip install --no-cache-dir pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null
else
python -m pip install --upgrade pip
- python -m pip install -q pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 2>&1 >/dev/null
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
fi
+ df -h
- name: Run lint-python.sh
shell: bash
run: |
chmod +x paimon-python/dev/lint-python.sh
- ./paimon-python/dev/lint-python.sh
+ ./paimon-python/dev/lint-python.sh -e pytest_torch
- requirement_version_compatible_test:
+ torch_test:
runs-on: ubuntu-latest
container: "python:3.10-slim"
@@ -110,17 +114,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2
- - name: Set up JDK ${{ env.JDK_VERSION }}
- uses: actions/setup-java@v4
- with:
- java-version: ${{ env.JDK_VERSION }}
- distribution: 'temurin'
-
- - name: Set up Maven
- uses: stCarolas/[email protected]
- with:
- maven-version: 3.8.8
-
- name: Install system dependencies
shell: bash
run: |
@@ -128,26 +121,50 @@ jobs:
build-essential \
git \
curl \
+ && apt-get clean \
&& rm -rf /var/lib/apt/lists/*
- - name: Verify Java and Maven installation
- run: |
- java -version
- mvn -version
-
- name: Verify Python version
run: python --version
- - name: Build Java
+ - name: Install Python dependencies
+ shell: bash
run: |
- echo "Start compiling modules"
- mvn -T 2C -B clean install -DskipTests
+ python -m pip install --upgrade pip
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
+ - name: Run lint-python.sh
+ shell: bash
+ run: |
+ chmod +x paimon-python/dev/lint-python.sh
+ ./paimon-python/dev/lint-python.sh -i pytest_torch
+
+ requirement_version_compatible_test:
+ runs-on: ubuntu-latest
+ container: "python:3.10-slim"
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Install system dependencies
+ shell: bash
+ run: |
+ apt-get update && apt-get install -y \
+ build-essential \
+ git \
+ curl \
+ && rm -rf /var/lib/apt/lists/*
+
+ - name: Verify Python version
+ run: python --version
- name: Install base Python dependencies
shell: bash
run: |
python -m pip install --upgrade pip
- python -m pip install -q \
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ python -m pip install --no-cache-dir \
pyroaring \
readerwriterlock==1.0.9 \
fsspec==2024.3.1 \
@@ -165,36 +182,37 @@ jobs:
requests \
parameterized==0.9.0 \
packaging
+
- name: Test requirement version compatibility
shell: bash
run: |
cd paimon-python
-
+
# Test Ray version compatibility
echo "=========================================="
echo "Testing Ray version compatibility"
echo "=========================================="
for ray_version in 2.44.0 2.48.0 2.53.0; do
echo "Testing Ray version: $ray_version"
-
+
# Install specific Ray version
- python -m pip install -q ray==$ray_version
-
+ python -m pip install --no-cache-dir -q ray==$ray_version
+
# Verify Ray version
python -c "import ray; print(f'Ray version: {ray.__version__}')"
python -c "from packaging.version import parse; import ray; assert parse(ray.__version__) == parse('$ray_version'), f'Expected Ray $ray_version, got {ray.__version__}'"
-
+
# Run tests
python -m pytest pypaimon/tests/ray_data_test.py::RayDataTest -v --tb=short || {
echo "Tests failed for Ray $ray_version"
exit 1
}
-
+
# Uninstall Ray to avoid conflicts
python -m pip uninstall -y ray
done
-
+
# Add other dependency version tests here in the future
# Example:
# echo "=========================================="
diff --git a/docs/content/program-api/python-api.md b/docs/content/program-api/python-api.md
index 406c8c1ef6..aa7773e049 100644
--- a/docs/content/program-api/python-api.md
+++ b/docs/content/program-api/python-api.md
@@ -72,6 +72,7 @@ catalog_options = {
}
catalog = CatalogFactory.create(catalog_options)
```
+
{{< /tab >}}
{{< /tabs >}}
@@ -473,6 +474,38 @@ ray_dataset = table_read.to_ray(splits)
See [Ray Data API Documentation](https://docs.ray.io/en/latest/data/api/doc/ray.data.read_datasource.html) for more details.
+### Read PyTorch Dataset
+
+This requires `torch` to be installed.
+
+You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`:
+
+```python
+from torch.utils.data import DataLoader
+
+table_read = read_builder.new_read()
+dataset = table_read.to_torch(splits, streaming=True)
+dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=2, # Concurrency to read data
+ shuffle=False
+)
+
+# Collect all data from dataloader
+for batch_idx, batch_data in enumerate(dataloader):
+ print(batch_data)
+
+# output:
+# {'user_id': tensor([1, 2]), 'behavior': ['a', 'b']}
+# {'user_id': tensor([3, 4]), 'behavior': ['c', 'd']}
+# {'user_id': tensor([5, 6]), 'behavior': ['e', 'f']}
+# {'user_id': tensor([7, 8]), 'behavior': ['g', 'h']}
+```
+
+When `streaming` is true, rows are read iteratively via a `torch.utils.data.IterableDataset`;
+when it is false, all data is loaded into memory up front as a map-style `Dataset`.
+
### Incremental Read
This API allows reading data committed between two snapshot timestamps. The
steps are as follows.
@@ -671,22 +704,22 @@ Key points about shard read:
The following shows the supported features of Python Paimon compared to Java Paimon:
**Catalog Level**
- - FileSystemCatalog
- - RestCatalog
+ - FileSystemCatalog
+ - RestCatalog
**Table Level**
- - Append Tables
- - `bucket = -1` (unaware)
- - `bucket > 0` (fixed)
- - Primary Key Tables
- - only support deduplicate
- - `bucket = -2` (postpone)
- - `bucket > 0` (fixed)
- - read with deletion vectors enabled
- - Read/Write Operations
- - Batch read and write for append tables and primary key tables
- - Predicate filtering
- - Overwrite semantics
- - Incremental reading of Delta data
- - Reading and writing blob data
- - `with_shard` feature
+ - Append Tables
+ - `bucket = -1` (unaware)
+ - `bucket > 0` (fixed)
+ - Primary Key Tables
+ - only support deduplicate
+ - `bucket = -2` (postpone)
+ - `bucket > 0` (fixed)
+ - read with deletion vectors enabled
+ - Read/Write Operations
+ - Batch read and write for append tables and primary key tables
+ - Predicate filtering
+ - Overwrite semantics
+ - Incremental reading of Delta data
+ - Reading and writing blob data
+ - `with_shard` feature
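The batch layout in the documentation example above (integer columns collated into tensors, string columns left as Python lists) is the behavior of PyTorch's default collate function on dict-shaped rows. A minimal, Paimon-independent sketch; the sample rows below are illustrative only:

```python
# Sketch of how torch's default collation turns a batch of dict rows into
# the {'user_id': tensor([...]), 'behavior': [...]} layout shown above.
from torch.utils.data.dataloader import default_collate

rows = [
    {'user_id': 1, 'behavior': 'a'},
    {'user_id': 2, 'behavior': 'b'},
]

batch = default_collate(rows)
print(batch['user_id'])   # tensor([1, 2]) -- numeric fields become tensors
print(batch['behavior'])  # ['a', 'b']     -- string fields stay a list of str
```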
diff --git a/paimon-python/dev/lint-python.sh b/paimon-python/dev/lint-python.sh
index d174b120ad..44be287149 100755
--- a/paimon-python/dev/lint-python.sh
+++ b/paimon-python/dev/lint-python.sh
@@ -107,7 +107,7 @@ function collect_checks() {
function get_all_supported_checks() {
_OLD_IFS=$IFS
IFS=$'\n'
- SUPPORT_CHECKS=("flake8_check" "pytest_check" "mixed_check") # control the calling sequence
+ SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" "mixed_check") # control the calling sequence
for fun in $(declare -F); do
if [[ `regexp_match "$fun" "_check$"` = true ]]; then
check_name="${fun:11}"
@@ -179,7 +179,7 @@ function pytest_check() {
TEST_DIR="pypaimon/tests/py36"
echo "Running tests for Python 3.6: $TEST_DIR"
else
- TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e"
+ TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e --ignore=pypaimon/tests/torch_read_test.py"
echo "Running tests for Python $PYTHON_VERSION (excluding py36): pypaimon/tests --ignore=pypaimon/tests/py36"
fi
@@ -197,7 +197,32 @@ function pytest_check() {
print_function "STAGE" "pytest checks... [SUCCESS]"
fi
}
+function pytest_torch_check() {
+ print_function "STAGE" "pytest torch checks"
+ if [ ! -f "$PYTEST_PATH" ]; then
+ echo "For some unknown reasons, the pytest package is not complete."
+ fi
+ # Get Python version
+ PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+ echo "Detected Python version: $PYTHON_VERSION"
+ TEST_DIR="pypaimon/tests/torch_read_test.py"
+ echo "Running tests for Python $PYTHON_VERSION: pypaimon/tests/torch_read_test.py"
+
+ # the return value of a pipeline is the status of the last command to exit
+ # with a non-zero status or zero if no command exited with a non-zero status
+ set -o pipefail
+ ($PYTEST_PATH $TEST_DIR) 2>&1 | tee -a $LOG_FILE
+
+ PYCODESTYLE_STATUS=$?
+ if [ $PYCODESTYLE_STATUS -ne 0 ]; then
+ print_function "STAGE" "pytest torch checks... [FAILED]"
+ # Stop the running script.
+ exit 1;
+ else
+ print_function "STAGE" "pytest torch checks... [SUCCESS]"
+ fi
+}
# Mixed tests check - runs Java-Python interoperability tests
function mixed_check() {
# Get Python version
@@ -279,7 +304,7 @@ usage: $0 [options]
-l list all checks supported.
Examples:
./lint-python.sh => exec all checks.
- ./lint-python.sh -e tox,flake8 => exclude checks tox,flake8.
+ ./lint-python.sh -e flake8 => exclude checks flake8.
./lint-python.sh -i flake8 => include checks flake8.
./lint-python.sh -i mixed => include checks mixed.
./lint-python.sh -l => list all checks supported.
diff --git a/paimon-python/dev/requirements.txt b/paimon-python/dev/requirements.txt
index 703adec8e8..e76827db3e 100644
--- a/paimon-python/dev/requirements.txt
+++ b/paimon-python/dev/requirements.txt
@@ -19,27 +19,23 @@
cachetools>=4.2,<6; python_version=="3.6"
cachetools>=5,<6; python_version>"3.6"
dataclasses>=0.8; python_version < "3.7"
-fastavro>=1.4,<2; python_version<"3.9"
-fastavro>=1.4,<2; python_version>="3.9"
+fastavro>=1.4,<2
fsspec>=2021.10,<2026; python_version<"3.8"
fsspec>=2023,<2026; python_version>="3.8"
ossfs>=2021.8; python_version<"3.8"
ossfs>=2023; python_version>="3.8"
-packaging>=21,<26; python_version<"3.8"
-packaging>=21,<26; python_version>="3.8"
+packaging>=21,<26
pandas>=1.1,<2; python_version < "3.7"
pandas>=1.3,<3; python_version >= "3.7" and python_version < "3.9"
pandas>=1.5,<3; python_version >= "3.9"
polars>=0.9,<1; python_version<"3.8"
-polars>=1,<2; python_version=="3.8"
-polars>=1,<2; python_version>"3.8"
+polars>=1,<2; python_version>="3.8"
pyarrow>=6,<7; python_version < "3.8"
-pyarrow>=16,<20; python_version >= "3.8" and python_version < "3.13"
-pyarrow>=16,<20; python_version >= "3.13"
+pyarrow>=16,<20; python_version >= "3.8"
+pylance>=0.20,<1; python_version>="3.9"
+pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
pyroaring
ray>=2.10,<3
readerwriterlock>=1,<2
-zstandard>=0.19,<1; python_version<"3.9"
-zstandard>=0.19,<1; python_version>="3.9"
-pylance>=0.20,<1; python_version>="3.9"
-pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
+torch
+zstandard>=0.19,<1
\ No newline at end of file
diff --git a/paimon-python/pypaimon/read/ray_datasource.py b/paimon-python/pypaimon/read/datasource.py
similarity index 67%
rename from paimon-python/pypaimon/read/ray_datasource.py
rename to paimon-python/pypaimon/read/datasource.py
index 905c8bddef..835effbf0b 100644
--- a/paimon-python/pypaimon/read/ray_datasource.py
+++ b/paimon-python/pypaimon/read/datasource.py
@@ -27,6 +27,7 @@ from typing import List, Optional, Iterable
import pyarrow
from packaging.version import parse
import ray
+import torch
from pypaimon.read.split import Split
from pypaimon.read.table_read import TableRead
@@ -40,8 +41,10 @@ RAY_VERSION_PER_TASK_ROW_LIMIT = "2.52.0"  # per_task_row_limit parameter introduced
from ray.data.datasource import Datasource
+from torch.utils.data import Dataset, IterableDataset
-class PaimonDatasource(Datasource):
+
+class RayDatasource(Datasource):
"""
Ray Data Datasource implementation for reading Paimon tables.
@@ -76,7 +79,7 @@ class PaimonDatasource(Datasource):
@staticmethod
def _distribute_splits_into_equal_chunks(
- splits: Iterable[Split], n_chunks: int
+ splits: Iterable[Split], n_chunks: int
) -> List[List[Split]]:
"""
Implement a greedy knapsack algorithm to distribute the splits across tasks,
@@ -88,7 +91,7 @@ class PaimonDatasource(Datasource):
# From largest to smallest, add the splits to the smallest chunk one at a time
for split in sorted(
- splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True
+ splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True
):
smallest_chunk = heapq.heappop(chunk_sizes)
chunks[smallest_chunk[1]].append(split)
@@ -132,11 +135,11 @@ class PaimonDatasource(Datasource):
# Create a partial function to avoid capturing self in closure
# This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107)
def _get_read_task(
- splits: List[Split],
- table=table,
- predicate=predicate,
- read_type=read_type,
- schema=schema,
+ splits: List[Split],
+ table=table,
+ predicate=predicate,
+ read_type=read_type,
+ schema=schema,
) -> Iterable[pyarrow.Table]:
"""Read function that will be executed by Ray workers."""
from pypaimon.read.table_read import TableRead
@@ -216,13 +219,128 @@ class PaimonDatasource(Datasource):
'read_fn': read_fn,
'metadata': metadata,
}
-
+
if parse(ray.__version__) >= parse(RAY_VERSION_SCHEMA_IN_READ_TASK):
read_task_kwargs['schema'] = schema
-
+
if parse(ray.__version__) >= parse(RAY_VERSION_PER_TASK_ROW_LIMIT) and per_task_row_limit is not None:
read_task_kwargs['per_task_row_limit'] = per_task_row_limit
read_tasks.append(ReadTask(**read_task_kwargs))
return read_tasks
+
+
+class TorchDataset(Dataset):
+ """
+ PyTorch Dataset implementation for reading Paimon table data.
+
+ This class enables Paimon table data to be used directly with PyTorch's
+ training pipeline, allowing for efficient data loading and batching.
+ """
+
+ def __init__(self, table_read: TableRead, splits: List[Split]):
+ """
+ Initialize TorchDataset.
+
+ Args:
+ table_read: TableRead instance for reading data
+ splits: List of splits to read
+ """
+ arrow_table = table_read.to_arrow(splits)
+ if arrow_table is None or arrow_table.num_rows == 0:
+ self._data = []
+ else:
+ self._data = arrow_table.to_pylist()
+
+ def __len__(self) -> int:
+ """
+ Return the total number of rows in the dataset.
+
+ Returns:
+ Total number of rows across all splits
+ """
+ return len(self._data)
+
+ def __getitem__(self, index: int):
+ """
+ Get a single item from the dataset.
+
+ Args:
+ index: Index of the item to retrieve
+
+ Returns:
+ Dictionary containing the row data
+ """
+ if not self._data:
+ return None
+
+ return self._data[index]
+
+
+class TorchIterDataset(IterableDataset):
+ """
+ PyTorch IterableDataset implementation for reading Paimon table data.
+
+ This class enables streaming data loading from Paimon tables, which is more
+ memory-efficient for large datasets. Data is read on-the-fly as needed,
+ rather than loading everything into memory upfront.
+ """
+
+ def __init__(self, table_read: TableRead, splits: List[Split]):
+ """
+ Initialize TorchIterDataset.
+
+ Args:
+ table_read: TableRead instance for reading data
+ splits: List of splits to read
+ """
+ self.table_read = table_read
+ self.splits = splits
+ # Get field names from read_type
+ self.field_names = [field.name for field in table_read.read_type]
+
+ def __iter__(self):
+ """
+ Iterate over the dataset, converting each OffsetRow to a dictionary.
+
+ Supports multi-worker data loading by partitioning splits across workers.
+ When num_workers > 0 in DataLoader, each worker will process a subset of splits.
+
+ Yields:
+ row data of dict type, where keys are column names
+ """
+ worker_info = torch.utils.data.get_worker_info()
+
+ if worker_info is None:
+ # Single-process data loading, iterate over all splits
+ splits_to_process = self.splits
+ else:
+ # Multi-process data loading, partition splits across workers
+ worker_id = worker_info.id
+ num_workers = worker_info.num_workers
+
+ # Calculate start and end indices for this worker
+ # Distribute splits evenly by slicing
+ total_splits = len(self.splits)
+ splits_per_worker = total_splits // num_workers
+ remainder = total_splits % num_workers
+
+ # Workers with id < remainder get one extra split
+ if worker_id < remainder:
+ start_idx = worker_id * (splits_per_worker + 1)
+ end_idx = start_idx + splits_per_worker + 1
+ else:
+ start_idx = worker_id * splits_per_worker + remainder
+ end_idx = start_idx + splits_per_worker
+
+ splits_to_process = self.splits[start_idx:end_idx]
+
+ worker_iterator = self.table_read.to_iterator(splits_to_process)
+
+ for offset_row in worker_iterator:
+ row_dict = {}
+ for i, field_name in enumerate(self.field_names):
+ value = offset_row.get_field(i)
+ row_dict[field_name] = value
+ yield row_dict
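The `__iter__` method of `TorchIterDataset` above slices `self.splits` so that each DataLoader worker reads a disjoint subset, with the first `remainder` workers getting one extra split. A standalone sketch of that slicing arithmetic (the helper name and sample numbers are illustrative, not part of pypaimon):

```python
# Standalone illustration of the even split partitioning used in
# TorchIterDataset.__iter__: the first `remainder` workers get one extra split.
from typing import List


def worker_slice(num_splits: int, worker_id: int, num_workers: int) -> range:
    """Return the index range of splits assigned to one DataLoader worker."""
    per_worker = num_splits // num_workers
    remainder = num_splits % num_workers
    if worker_id < remainder:
        start = worker_id * (per_worker + 1)
        end = start + per_worker + 1
    else:
        start = worker_id * per_worker + remainder
        end = start + per_worker
    return range(start, end)


# 10 splits across 3 workers -> sizes 4, 3, 3 with no overlap and no gaps.
assignments: List[range] = [worker_slice(10, w, 3) for w in range(3)]
assert [len(r) for r in assignments] == [4, 3, 3]
assert sorted(i for r in assignments for i in r) == list(range(10))
```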
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 953384cc7d..7e8dbda412 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -165,8 +165,8 @@ class TableRead:
if override_num_blocks is not None and override_num_blocks < 1:
raise ValueError(f"override_num_blocks must be at least 1, got {override_num_blocks}")
- from pypaimon.read.ray_datasource import PaimonDatasource
- datasource = PaimonDatasource(self, splits)
+ from pypaimon.read.datasource import RayDatasource
+ datasource = RayDatasource(self, splits)
return ray.data.read_datasource(
datasource,
ray_remote_args=ray_remote_args,
@@ -175,6 +175,17 @@ class TableRead:
**read_args
)
+ def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset":
+ """Wrap Paimon table data to PyTorch Dataset."""
+ if streaming:
+ from pypaimon.read.datasource import TorchIterDataset
+ dataset = TorchIterDataset(self, splits)
+ return dataset
+ else:
+ from pypaimon.read.datasource import TorchDataset
+ dataset = TorchDataset(self, splits)
+ return dataset
+
def _create_split_read(self, split: Split) -> SplitRead:
if self.table.is_primary_key_table and not split.raw_convertible:
return MergeFileSplitRead(
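With `streaming=False` (the default of `to_torch` above), the returned `TorchDataset` materializes all rows through `to_arrow(...).to_pylist()`, so it supports `len()`, random access, and shuffling in the DataLoader. A hedged usage sketch, assuming a `read_builder` created as in the Python API docs:

```python
# Sketch only: assumes `read_builder` was obtained as in the Python API docs.
from torch.utils.data import DataLoader

table_scan = read_builder.new_scan()
table_read = read_builder.new_read()
splits = table_scan.plan().splits()

# streaming=False loads everything into memory and returns a map-style Dataset
dataset = table_read.to_torch(splits, streaming=False)
print(len(dataset))  # total row count
print(dataset[0])    # a single row as a dict

# Random access makes shuffle=True possible, unlike the IterableDataset path
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in dataloader:
    print(batch)
```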
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index f87f73ded7..9925e21be5 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2644,7 +2644,7 @@ class DataBlobWriterTest(unittest.TestCase):
# Create and start multiple threads
threads = []
- num_threads = 100
+ num_threads = 10
for i in range(num_threads):
thread = threading.Thread(
target=write_blob_data,
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py
index d65658ef5c..adb0ff4f25 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -438,44 +438,6 @@ class AoReaderTest(unittest.TestCase):
}, schema=self.pa_schema).sort_by('user_id')
self.assertEqual(expected, actual)
- def _write_test_table(self, table):
- write_builder = table.new_batch_write_builder()
-
- # first write
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- data1 = {
- 'user_id': [1, 2, 3, 4],
- 'item_id': [1001, 1002, 1003, 1004],
- 'behavior': ['a', 'b', 'c', None],
- 'dt': ['p1', 'p1', 'p2', 'p1'],
- }
- pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
- table_write.write_arrow(pa_table)
- table_commit.commit(table_write.prepare_commit())
- table_write.close()
- table_commit.close()
-
- # second write
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- data2 = {
- 'user_id': [5, 6, 7, 8],
- 'item_id': [1005, 1006, 1007, 1008],
- 'behavior': ['e', 'f', 'g', 'h'],
- 'dt': ['p2', 'p1', 'p2', 'p2'],
- }
- pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
- table_write.write_arrow(pa_table)
- table_commit.commit(table_write.prepare_commit())
- table_write.close()
- table_commit.close()
-
- def _read_test_table(self, read_builder):
- table_read = read_builder.new_read()
- splits = read_builder.new_scan().plan().splits()
- return table_read.to_arrow(splits)
-
def test_concurrent_writes_with_retry(self):
"""Test concurrent writes to verify retry mechanism works correctly."""
import threading
@@ -529,7 +491,7 @@ class AoReaderTest(unittest.TestCase):
# Create and start multiple threads
threads = []
- num_threads = 100
+ num_threads = 10
for i in range(num_threads):
thread = threading.Thread(
target=write_data,
@@ -576,3 +538,41 @@ class AoReaderTest(unittest.TestCase):
f"got {latest_snapshot.id}")
print(f"✓ Iteration {test_iteration + 1}/{iter_num} completed successfully")
+
+ def _write_test_table(self, table):
+ write_builder = table.new_batch_write_builder()
+
+ # first write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'behavior': ['a', 'b', 'c', None],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # second write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ def _read_test_table(self, read_builder):
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ return table_read.to_arrow(splits)
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index 731203385d..c22346afe7 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -479,7 +479,7 @@ class PkReaderTest(unittest.TestCase):
# Create and start multiple threads
threads = []
- num_threads = 100
+ num_threads = 10
for i in range(num_threads):
thread = threading.Thread(
target=write_data,
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py
new file mode 100644
index 0000000000..b6862c6cb1
--- /dev/null
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -0,0 +1,635 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+from parameterized import parameterized
+from torch.utils.data import DataLoader
+
+from pypaimon import CatalogFactory, Schema
+
+from pypaimon.table.file_store_table import FileStoreTable
+
+
+class TorchReadTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({
+ 'warehouse': cls.warehouse
+ })
+ cls.catalog.create_database('default', True)
+
+ cls.pa_schema = pa.schema([
+ ('user_id', pa.int32()),
+ ('item_id', pa.int64()),
+ ('behavior', pa.string()),
+ ('dt', pa.string())
+ ])
+ cls.expected = pa.Table.from_pydict({
+ 'user_id': [1, 2, 3, 4, 5, 6, 7, 8],
+ 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
+ 'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
+ 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
+ }, schema=cls.pa_schema)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ @parameterized.expand([True, False])
+ def test_torch_read(self, is_streaming: bool = False):
+ schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+ self.catalog.create_table(f'default.test_torch_read_{str(is_streaming)}', schema, False)
+ table = self.catalog.get_table(f'default.test_torch_read_{str(is_streaming)}')
+ self._write_test_table(table)
+
+ read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=is_streaming)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=2,
+ shuffle=False
+ )
+
+ # Collect all data from dataloader
+ all_user_ids = []
+ all_behaviors = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ behaviors = batch_data['behavior']
+ all_user_ids.extend(user_ids)
+ all_behaviors.extend(behaviors)
+
+ # Sort by user_id for comparison
+ sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0])
+ sorted_user_ids = [x[0] for x in sorted_data]
+ sorted_behaviors = [x[1] for x in sorted_data]
+
+ # Expected data (sorted by user_id)
+ expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+ expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+
+ # Verify results
+ self.assertEqual(sorted_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}")
+ self.assertEqual(sorted_behaviors, expected_behaviors,
+                  f"Behaviors mismatch. Expected {expected_behaviors}, got {sorted_behaviors}")
+
+ print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data")
+
+ def test_blob_torch_read(self):
+ """Test end-to-end blob functionality using blob descriptors."""
+ import random
+ from pypaimon import Schema
+ from pypaimon.table.row.blob import BlobDescriptor
+
+ # Create schema with blob column
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('picture', pa.large_binary()),
+ ])
+
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ options={
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true',
+ 'blob-as-descriptor': 'true'
+ }
+ )
+
+ # Create table
+ self.catalog.create_table('default.test_blob_torch_read', schema, False)
+ table: FileStoreTable = self.catalog.get_table('default.test_blob_torch_read')
+
+ # Create test blob data (1MB)
+ blob_data = bytearray(1024 * 1024)
+ random.seed(42) # For reproducible tests
+ for i in range(len(blob_data)):
+ blob_data[i] = random.randint(0, 255)
+ blob_data = bytes(blob_data)
+
+ # Create external blob file
+ external_blob_path = os.path.join(self.tempdir, 'external_blob')
+ with open(external_blob_path, 'wb') as f:
+ f.write(blob_data)
+
+ # Create blob descriptor pointing to external file
+ blob_descriptor = BlobDescriptor(external_blob_path, 0, len(blob_data))
+
+ # Create test data with blob descriptor
+ test_data = pa.Table.from_pydict({
+ 'id': [1],
+ 'picture': [blob_descriptor.serialize()]
+ }, schema=pa_schema)
+
+ # Write data using table API
+ write_builder = table.new_batch_write_builder()
+ writer = write_builder.new_write()
+ writer.write_arrow(test_data)
+
+ # Commit the data
+ commit_messages = writer.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+
+ # Read data back
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ result = table_read.to_torch(table_scan.plan().splits())
+
+ dataloader = DataLoader(
+ result,
+ batch_size=1,
+ num_workers=0,
+ shuffle=False
+ )
+
+ # Collect and verify data
+ all_ids = []
+ all_pictures = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ ids = batch_data['id'].tolist()
+ pictures = batch_data['picture']
+ all_ids.extend(ids)
+ all_pictures.extend(pictures)
+
+ # Verify results
+ self.assertEqual(len(all_ids), 1, "Should have exactly 1 row")
+ self.assertEqual(all_ids[0], 1, "ID should be 1")
+
+ # Verify blob descriptor
+ picture_bytes = all_pictures[0]
+ self.assertIsInstance(picture_bytes, bytes, "Picture should be bytes")
+
+ # Deserialize and verify blob descriptor
+ from pypaimon.table.row.blob import BlobDescriptor
+ read_blob_descriptor = BlobDescriptor.deserialize(picture_bytes)
+ self.assertEqual(read_blob_descriptor.length, len(blob_data),
+                  f"Blob length mismatch. Expected {len(blob_data)}, got {read_blob_descriptor.length}")
+ self.assertGreaterEqual(read_blob_descriptor.offset, 0, "Offset should be non-negative")
+
+ # Read and verify blob content
+ from pypaimon.common.uri_reader import UriReaderFactory
+ from pypaimon.common.options.config import CatalogOptions
+ from pypaimon.table.row.blob import Blob
+
+ catalog_options = {CatalogOptions.WAREHOUSE.key(): self.warehouse}
+ uri_reader_factory = UriReaderFactory(catalog_options)
+ uri_reader = uri_reader_factory.create(read_blob_descriptor.uri)
+ blob = Blob.from_descriptor(uri_reader, read_blob_descriptor)
+
+ # Verify blob data matches original
+ read_blob_data = blob.to_data()
+ self.assertEqual(len(read_blob_data), len(blob_data),
+                  f"Blob data length mismatch. Expected {len(blob_data)}, got {len(read_blob_data)}")
+ self.assertEqual(read_blob_data, blob_data, "Blob data content should match original")
+
+ print(f"✓ Blob torch read test passed: Successfully read and verified {len(blob_data)} bytes of blob data")
+
+ def test_torch_read_pk_table(self):
+ """Test torch read with primary key table."""
+ # Create PK table with user_id as primary key and behavior as partition key
+ schema = Schema.from_pyarrow_schema(
+ self.pa_schema,
+ primary_keys=['user_id', 'behavior'],
+ partition_keys=['behavior'],
+ options={'bucket': 2}
+ )
+ self.catalog.create_table('default.test_pk_table', schema, False)
+ table = self.catalog.get_table('default.test_pk_table')
+ self._write_test_table(table)
+
+ read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=3,
+ shuffle=False
+ )
+
+ # Collect all data from dataloader
+ all_user_ids = []
+ all_behaviors = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ behaviors = batch_data['behavior']
+ all_user_ids.extend(user_ids)
+ all_behaviors.extend(behaviors)
+
+ # Sort by user_id for comparison
+ sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0])
+ sorted_user_ids = [x[0] for x in sorted_data]
+ sorted_behaviors = [x[1] for x in sorted_data]
+
+ # Expected data (sorted by user_id)
+ expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+ expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+
+ # Verify results
+ self.assertEqual(sorted_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}")
+ self.assertEqual(sorted_behaviors, expected_behaviors,
+                  f"Behaviors mismatch. Expected {expected_behaviors}, got {sorted_behaviors}")
+
+ print(f"✓ PK table test passed: Successfully read {len(all_user_ids)} rows with correct data")
+
+ def test_torch_read_large_append_table(self):
+ """Test torch read with large data volume on append-only table."""
+ # Create append-only table
+ schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+ self.catalog.create_table('default.test_large_append', schema, False)
+ table = self.catalog.get_table('default.test_large_append')
+
+ # Write large amount of data
+ write_builder = table.new_batch_write_builder()
+ total_rows = 100000
+ batch_size = 10000
+ num_batches = total_rows // batch_size
+
+ print(f"\n{'=' * 60}")
+ print(f"Writing {total_rows} rows to append-only table...")
+ print(f"{'=' * 60}")
+
+ for batch_idx in range(num_batches):
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ start_id = batch_idx * batch_size + 1
+ end_id = start_id + batch_size
+
+ data = {
+ 'user_id': list(range(start_id, end_id)),
+ 'item_id': [1000 + i for i in range(start_id, end_id)],
+ 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)],
+ 'dt': [f'p{i % 4}' for i in range(batch_size)],
+ }
+ pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ if (batch_idx + 1) % 2 == 0:
+ print(f" Written {(batch_idx + 1) * batch_size} rows...")
+
+ # Read data using torch
+ print(f"\nReading {total_rows} rows using Torch DataLoader...")
+
+ read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+
+ print(f"Total splits: {len(splits)}")
+
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1000,
+ num_workers=4,
+ shuffle=False
+ )
+
+ # Collect all data
+ all_user_ids = []
+ batch_count = 0
+ for batch_idx, batch_data in enumerate(dataloader):
+ batch_count += 1
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ if (batch_idx + 1) % 20 == 0:
+ print(f" Read {len(all_user_ids)} rows...")
+
+ all_user_ids.sort()
+ # Verify data
+ self.assertEqual(len(all_user_ids), total_rows,
+                  f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+ self.assertEqual(all_user_ids, list(range(1, total_rows + 1)),
+                  f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+ print(f"\n{'=' * 60}")
+ print("✓ Large append table test passed!")
+ print(f" Total rows: {total_rows}")
+ print(f" Total batches: {batch_count}")
+ print(f"{'=' * 60}\n")
+
+ def test_torch_read_large_pk_table(self):
+ """Test torch read with large data volume on primary key table."""
+
+ # Create PK table
+ schema = Schema.from_pyarrow_schema(
+ self.pa_schema,
+ primary_keys=['user_id'],
+ partition_keys=['dt'],
+ options={'bucket': '4'}
+ )
+ self.catalog.create_table('default.test_large_pk', schema, False)
+ table = self.catalog.get_table('default.test_large_pk')
+
+ # Write large amount of data
+ write_builder = table.new_batch_write_builder()
+ total_rows = 100000
+ batch_size = 10000
+ num_batches = total_rows // batch_size
+
+ print(f"\n{'=' * 60}")
+ print(f"Writing {total_rows} rows to PK table...")
+ print(f"{'=' * 60}")
+
+ for batch_idx in range(num_batches):
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ start_id = batch_idx * batch_size + 1
+ end_id = start_id + batch_size
+
+ data = {
+ 'user_id': list(range(start_id, end_id)),
+ 'item_id': [1000 + i for i in range(start_id, end_id)],
+ 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)],
+ 'dt': [f'p{i % 4}' for i in range(batch_size)],
+ }
+ pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ if (batch_idx + 1) % 2 == 0:
+ print(f" Written {(batch_idx + 1) * batch_size} rows...")
+
+ # Read data using torch
+ print(f"\nReading {total_rows} rows using Torch DataLoader...")
+
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+
+ print(f"Total splits: {len(splits)}")
+
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1000,
+ num_workers=8,
+ shuffle=False
+ )
+
+ # Collect all data
+ all_user_ids = []
+ batch_count = 0
+ for batch_idx, batch_data in enumerate(dataloader):
+ batch_count += 1
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ if (batch_idx + 1) % 20 == 0:
+ print(f" Read {len(all_user_ids)} rows...")
+
+ all_user_ids.sort()
+ # Verify data
+ self.assertEqual(len(all_user_ids), total_rows,
+                  f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+
+ self.assertEqual(all_user_ids, list(range(1, total_rows + 1)),
+                  f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+
+ print(f"\n{'=' * 60}")
+ print("✓ Large PK table test passed!")
+ print(f" Total rows: {total_rows}")
+ print(f" Total batches: {batch_count}")
+ print(" Primary key uniqueness: ✓")
+ print(f"{'=' * 60}\n")
+
+ def test_torch_read_with_predicate(self):
+ """Test torch read with predicate filtering."""
+
+ schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+ self.catalog.create_table('default.test_predicate', schema, False)
+ table = self.catalog.get_table('default.test_predicate')
+ self._write_test_table(table)
+
+ # Test case 1: Filter by user_id > 4
+ print(f"\n{'=' * 60}")
+ print("Test Case 1: user_id > 4")
+ print(f"{'=' * 60}")
+ predicate_builder = table.new_read_builder().new_predicate_builder()
+
+ predicate = predicate_builder.greater_than('user_id', 4)
+ read_builder = table.new_read_builder().with_filter(predicate)
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=0,
+ shuffle=False
+ )
+
+ all_user_ids = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ all_user_ids.sort()
+ expected_user_ids = [5, 6, 7, 8]
+ self.assertEqual(all_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+ print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}")
+
+ # Test case 2: Filter by user_id <= 3
+ print(f"\n{'=' * 60}")
+ print("Test Case 2: user_id <= 3")
+ print(f"{'=' * 60}")
+
+ predicate = predicate_builder.less_or_equal('user_id', 3)
+ read_builder = table.new_read_builder().with_filter(predicate)
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=0,
+ shuffle=False
+ )
+
+ all_user_ids = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ all_user_ids.sort()
+ expected_user_ids = [1, 2, 3]
+ self.assertEqual(all_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+ print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}")
+
+ # Test case 3: Filter by behavior = 'a'
+ print(f"\n{'=' * 60}")
+ print("Test Case 3: behavior = 'a'")
+ print(f"{'=' * 60}")
+
+ predicate = predicate_builder.equal('behavior', 'a')
+ read_builder = table.new_read_builder().with_filter(predicate)
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=0,
+ shuffle=False
+ )
+
+ all_user_ids = []
+ all_behaviors = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ behaviors = batch_data['behavior']
+ all_user_ids.extend(user_ids)
+ all_behaviors.extend(behaviors)
+
+ expected_user_ids = [1]
+ expected_behaviors = ['a']
+ self.assertEqual(all_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+ self.assertEqual(all_behaviors, expected_behaviors,
+                  f"Behaviors mismatch. Expected {expected_behaviors}, got {all_behaviors}")
+ print(f"✓ Filtered {len(all_user_ids)} rows: user_ids={all_user_ids}, behaviors={all_behaviors}")
+
+ # Test case 4: Filter by user_id IN (2, 4, 6)
+ print(f"\n{'=' * 60}")
+ print("Test Case 4: user_id IN (2, 4, 6)")
+ print(f"{'=' * 60}")
+
+ predicate = predicate_builder.is_in('user_id', [2, 4, 6])
+ read_builder = table.new_read_builder().with_filter(predicate)
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=0,
+ shuffle=False
+ )
+
+ all_user_ids = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ all_user_ids.sort()
+ expected_user_ids = [2, 4, 6]
+ self.assertEqual(all_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+ print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}")
+
+ # Test case 5: Combined filter (user_id > 2 AND user_id < 7)
+ print(f"\n{'=' * 60}")
+ print("Test Case 5: user_id > 2 AND user_id < 7")
+ print(f"{'=' * 60}")
+
+ predicate1 = predicate_builder.greater_than('user_id', 2)
+ predicate2 = predicate_builder.less_than('user_id', 7)
+ combined_predicate = predicate_builder.and_predicates([predicate1, predicate2])
+ read_builder = table.new_read_builder().with_filter(combined_predicate)
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ dataset = table_read.to_torch(splits, streaming=True)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=2,
+ num_workers=0,
+ shuffle=False
+ )
+
+ all_user_ids = []
+ for batch_idx, batch_data in enumerate(dataloader):
+ user_ids = batch_data['user_id'].tolist()
+ all_user_ids.extend(user_ids)
+
+ all_user_ids.sort()
+ expected_user_ids = [3, 4, 5, 6]
+ self.assertEqual(all_user_ids, expected_user_ids,
+                  f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+ print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}")
+
+ print(f"\n{'=' * 60}")
+ print("✓ All predicate test cases passed!")
+ print(f"{'=' * 60}\n")
+
+ def _write_test_table(self, table):
+ write_builder = table.new_batch_write_builder()
+
+ # first write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'behavior': ['a', 'b', 'c', 'd'],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # second write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ 'dt': ['p2', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ def _read_test_table(self, read_builder):
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ return table_read.to_arrow(splits)
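For reference, the `_distribute_splits_into_equal_chunks` helper in datasource.py above balances splits by size with a greedy, heap-based strategy ("from largest to smallest, add the splits to the smallest chunk one at a time"). A self-contained sketch of that idea with made-up sizes; it is not the pypaimon implementation itself:

```python
# Greedy "largest item to the currently smallest bucket" distribution,
# mirroring the comment in datasource.py. Sizes are illustrative only.
import heapq
from typing import List


def distribute(sizes: List[int], n_chunks: int) -> List[List[int]]:
    chunks: List[List[int]] = [[] for _ in range(n_chunks)]
    # heap of (current_total_size, chunk_index)
    heap = [(0, i) for i in range(n_chunks)]
    heapq.heapify(heap)
    for size in sorted(sizes, reverse=True):
        total, idx = heapq.heappop(heap)
        chunks[idx].append(size)
        heapq.heappush(heap, (total + size, idx))
    return chunks


print(distribute([90, 50, 40, 30, 10, 10], 3))
# [[90], [50, 10, 10], [40, 30]] -- chunk totals stay close to each other
```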