This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new da8e246a19 [python] Support read paimon table as pytorch dataset 
(#6987)
da8e246a19 is described below

commit da8e246a197d055c273696bd486e5f167621ee5f
Author: umi <[email protected]>
AuthorDate: Fri Jan 9 19:19:01 2026 +0800

    [python] Support read paimon table as pytorch dataset (#6987)
---
 .github/workflows/paimon-python-checks.yml         |  82 +--
 docs/content/program-api/python-api.md             |  67 ++-
 paimon-python/dev/lint-python.sh                   |  31 +-
 paimon-python/dev/requirements.txt                 |  20 +-
 .../read/{ray_datasource.py => datasource.py}      | 138 ++++-
 paimon-python/pypaimon/read/table_read.py          |  15 +-
 paimon-python/pypaimon/tests/blob_table_test.py    |   2 +-
 .../pypaimon/tests/reader_append_only_test.py      |  78 +--
 .../pypaimon/tests/reader_primary_key_test.py      |   2 +-
 paimon-python/pypaimon/tests/torch_read_test.py    | 635 +++++++++++++++++++++
 10 files changed, 953 insertions(+), 117 deletions(-)

diff --git a/.github/workflows/paimon-python-checks.yml 
b/.github/workflows/paimon-python-checks.yml
index 4fb7fe07e4..ff2929c0fa 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -46,7 +46,7 @@ jobs:
     container: "python:${{ matrix.python-version }}-slim"
     strategy:
       matrix:
-        python-version: ['3.6.15', '3.10']
+        python-version: [ '3.6.15', '3.10' ]
 
     steps:
       - name: Checkout code
@@ -70,6 +70,7 @@ jobs:
             build-essential \
             git \
             curl \
+            && apt-get clean \
             && rm -rf /var/lib/apt/lists/*
 
       - name: Verify Java and Maven installation
@@ -88,21 +89,24 @@ jobs:
       - name: Install Python dependencies
         shell: bash
         run: |
+          df -h
           if [[ "${{ matrix.python-version }}" == "3.6.15" ]]; then
             python -m pip install --upgrade pip==21.3.1
             python --version
-            python -m pip install -q pyroaring readerwriterlock==1.0.9 
'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 
pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 
dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 
2>&1 >/dev/null
+            python -m pip install --no-cache-dir pyroaring 
readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 
'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 
'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest 
py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null
           else
             python -m pip install --upgrade pip
-            python -m pip install -q pyroaring readerwriterlock==1.0.9 
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 
py4j==0.10.9.9 requests parameterized==0.9.0 2>&1 >/dev/null
+            pip install torch --index-url https://download.pytorch.org/whl/cpu
+            python -m pip install pyroaring readerwriterlock==1.0.9 
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 
py4j==0.10.9.9 requests parameterized==0.9.0
           fi
+          df -h
       - name: Run lint-python.sh
         shell: bash
         run: |
           chmod +x paimon-python/dev/lint-python.sh
-          ./paimon-python/dev/lint-python.sh
+          ./paimon-python/dev/lint-python.sh -e pytest_torch
 
-  requirement_version_compatible_test:
+  torch_test:
     runs-on: ubuntu-latest
     container: "python:3.10-slim"
 
@@ -110,17 +114,6 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v2
 
-      - name: Set up JDK ${{ env.JDK_VERSION }}
-        uses: actions/setup-java@v4
-        with:
-          java-version: ${{ env.JDK_VERSION }}
-          distribution: 'temurin'
-
-      - name: Set up Maven
-        uses: stCarolas/[email protected]
-        with:
-          maven-version: 3.8.8
-
       - name: Install system dependencies
         shell: bash
         run: |
@@ -128,26 +121,50 @@ jobs:
             build-essential \
             git \
             curl \
+            && apt-get clean \
             && rm -rf /var/lib/apt/lists/*
 
-      - name: Verify Java and Maven installation
-        run: |
-          java -version
-          mvn -version
-
       - name: Verify Python version
         run: python --version
 
-      - name: Build Java
+      - name: Install Python dependencies
+        shell: bash
         run: |
-          echo "Start compiling modules"
-          mvn -T 2C -B clean install -DskipTests
+            python -m pip install --upgrade pip
+            pip install torch --index-url https://download.pytorch.org/whl/cpu
+            python -m pip install pyroaring readerwriterlock==1.0.9 
fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 
fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 
numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 
py4j==0.10.9.9 requests parameterized==0.9.0
+      - name: Run lint-python.sh
+        shell: bash
+        run: |
+          chmod +x paimon-python/dev/lint-python.sh
+          ./paimon-python/dev/lint-python.sh -i pytest_torch
+
+  requirement_version_compatible_test:
+    runs-on: ubuntu-latest
+    container: "python:3.10-slim"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Install system dependencies
+        shell: bash
+        run: |
+          apt-get update && apt-get install -y \
+            build-essential \
+            git \
+            curl \
+            && rm -rf /var/lib/apt/lists/*
+
+      - name: Verify Python version
+        run: python --version
 
       - name: Install base Python dependencies
         shell: bash
         run: |
           python -m pip install --upgrade pip
-          python -m pip install -q \
+          pip install torch --index-url https://download.pytorch.org/whl/cpu
+          python -m pip install --no-cache-dir \
             pyroaring \
             readerwriterlock==1.0.9 \
             fsspec==2024.3.1 \
@@ -165,36 +182,37 @@ jobs:
             requests \
             parameterized==0.9.0 \
             packaging
+            
 
       - name: Test requirement version compatibility
         shell: bash
         run: |
           cd paimon-python
-          
+
           # Test Ray version compatibility
           echo "=========================================="
           echo "Testing Ray version compatibility"
           echo "=========================================="
           for ray_version in 2.44.0 2.48.0 2.53.0; do
             echo "Testing Ray version: $ray_version"
-            
+
             # Install specific Ray version
-            python -m pip install -q ray==$ray_version
-            
+            python -m pip install --no-cache-dir -q ray==$ray_version
+
             # Verify Ray version
             python -c "import ray; print(f'Ray version: {ray.__version__}')"
             python -c "from packaging.version import parse; import ray; assert 
parse(ray.__version__) == parse('$ray_version'), f'Expected Ray $ray_version, 
got {ray.__version__}'"
-            
+
             # Run tests
             python -m pytest pypaimon/tests/ray_data_test.py::RayDataTest -v 
--tb=short || {
               echo "Tests failed for Ray $ray_version"
               exit 1
             }
-            
+
             # Uninstall Ray to avoid conflicts
             python -m pip uninstall -y ray
           done
-          
+
           # Add other dependency version tests here in the future
           # Example:
           # echo "=========================================="
diff --git a/docs/content/program-api/python-api.md 
b/docs/content/program-api/python-api.md
index 406c8c1ef6..aa7773e049 100644
--- a/docs/content/program-api/python-api.md
+++ b/docs/content/program-api/python-api.md
@@ -72,6 +72,7 @@ catalog_options = {
 }
 catalog = CatalogFactory.create(catalog_options)
 ```
+
 {{< /tab >}}
 {{< /tabs >}}
 
@@ -473,6 +474,38 @@ ray_dataset = table_read.to_ray(splits)
 
 See [Ray Data API 
Documentation](https://docs.ray.io/en/latest/data/api/doc/ray.data.read_datasource.html)
 for more details.
 
+### Read PyTorch Dataset
+
+This requires `torch` to be installed.
+
+You can read all the data into a `torch.utils.data.Dataset` or 
`torch.utils.data.IterableDataset`:
+
+```python
+from torch.utils.data import DataLoader
+
+table_read = read_builder.new_read()
+dataset = table_read.to_torch(splits, streaming=True)
+dataloader = DataLoader(
+    dataset,
+    batch_size=2,
+    num_workers=2,  # Concurrency to read data
+    shuffle=False
+)
+
+# Collect all data from dataloader
+for batch_idx, batch_data in enumerate(dataloader):
+    print(batch_data)
+
+# output:
+#   {'user_id': tensor([1, 2]), 'behavior': ['a', 'b']}
+#   {'user_id': tensor([3, 4]), 'behavior': ['c', 'd']}
+#   {'user_id': tensor([5, 6]), 'behavior': ['e', 'f']}
+#   {'user_id': tensor([7, 8]), 'behavior': ['g', 'h']}
+```
+
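+When `streaming` is `True`, a `torch.utils.data.IterableDataset` is returned and rows are read lazily during iteration;
+when it is `False` (the default), all rows are loaded into memory up front and a map-style `torch.utils.data.Dataset` is returned.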
+
 ### Incremental Read
 
 This API allows reading data committed between two snapshot timestamps. The 
steps are as follows.
@@ -671,22 +704,22 @@ Key points about shard read:
 The following shows the supported features of Python Paimon compared to Java 
Paimon:
 
 **Catalog Level**
-   - FileSystemCatalog
-   - RestCatalog
+  - FileSystemCatalog
+  - RestCatalog
 
 **Table Level**
-   - Append Tables
-     - `bucket = -1` (unaware)
-     - `bucket > 0` (fixed)
-   - Primary Key Tables
-     - only support deduplicate
-     - `bucket = -2` (postpone)
-     - `bucket > 0` (fixed)
-     - read with deletion vectors enabled
-   - Read/Write Operations
-     - Batch read and write for append tables and primary key tables
-     - Predicate filtering
-     - Overwrite semantics
-     - Incremental reading of Delta data
-     - Reading and writing blob data
-     - `with_shard` feature
+  - Append Tables
+    - `bucket = -1` (unaware)
+    - `bucket > 0` (fixed)
+  - Primary Key Tables
+    - only support deduplicate
+    - `bucket = -2` (postpone)
+    - `bucket > 0` (fixed)
+    - read with deletion vectors enabled
+  - Read/Write Operations
+    - Batch read and write for append tables and primary key tables
+    - Predicate filtering
+    - Overwrite semantics
+    - Incremental reading of Delta data
+    - Reading and writing blob data
+    - `with_shard` feature
diff --git a/paimon-python/dev/lint-python.sh b/paimon-python/dev/lint-python.sh
index d174b120ad..44be287149 100755
--- a/paimon-python/dev/lint-python.sh
+++ b/paimon-python/dev/lint-python.sh
@@ -107,7 +107,7 @@ function collect_checks() {
 function get_all_supported_checks() {
     _OLD_IFS=$IFS
     IFS=$'\n'
-    SUPPORT_CHECKS=("flake8_check" "pytest_check" "mixed_check") # control the 
calling sequence
+    SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" 
"mixed_check") # control the calling sequence
     for fun in $(declare -F); do
         if [[ `regexp_match "$fun" "_check$"` = true ]]; then
             check_name="${fun:11}"
@@ -179,7 +179,7 @@ function pytest_check() {
         TEST_DIR="pypaimon/tests/py36"
         echo "Running tests for Python 3.6: $TEST_DIR"
     else
-        TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 
--ignore=pypaimon/tests/e2e"
+        TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 
--ignore=pypaimon/tests/e2e --ignore=pypaimon/tests/torch_read_test.py"
         echo "Running tests for Python $PYTHON_VERSION (excluding py36): 
pypaimon/tests --ignore=pypaimon/tests/py36"
     fi
 
@@ -197,7 +197,32 @@ function pytest_check() {
         print_function "STAGE" "pytest checks... [SUCCESS]"
     fi
 }
+# Torch tests check - runs only pypaimon/tests/torch_read_test.py, which
+# pytest_check excludes via --ignore so the two stages can be run separately.
+function pytest_torch_check() {
+    print_function "STAGE" "pytest torch checks"
+    if [ ! -f "$PYTEST_PATH" ]; then
+        echo "For some unknown reasons, the pytest package is not complete."
+    fi
 
+    # Get Python version (only used for the log lines below)
+    PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+    echo "Detected Python version: $PYTHON_VERSION"
+    TEST_DIR="pypaimon/tests/torch_read_test.py"
+    echo "Running tests for Python $PYTHON_VERSION: $TEST_DIR"
+
+    # the return value of a pipeline is the status of the last command to exit
+    # with a non-zero status or zero if no command exited with a non-zero status
+    set -o pipefail
+    ($PYTEST_PATH $TEST_DIR) 2>&1 | tee -a $LOG_FILE
+
+    # $? is expanded before `local` runs, so this captures the pipeline status.
+    local torch_status=$?
+    if [ $torch_status -ne 0 ]; then
+        # Name the torch stage explicitly so the log is distinguishable from pytest_check.
+        print_function "STAGE" "pytest torch checks... [FAILED]"
+        # Stop the running script.
+        exit 1;
+    else
+        print_function "STAGE" "pytest torch checks... [SUCCESS]"
+    fi
+}
 # Mixed tests check - runs Java-Python interoperability tests
 function mixed_check() {
     # Get Python version
@@ -279,7 +304,7 @@ usage: $0 [options]
 -l          list all checks supported.
 Examples:
   ./lint-python.sh                 =>  exec all checks.
-  ./lint-python.sh -e tox,flake8   =>  exclude checks tox,flake8.
+  ./lint-python.sh -e flake8       =>  exclude checks flake8.
   ./lint-python.sh -i flake8       =>  include checks flake8.
   ./lint-python.sh -i mixed        =>  include checks mixed.
   ./lint-python.sh -l              =>  list all checks supported.
diff --git a/paimon-python/dev/requirements.txt 
b/paimon-python/dev/requirements.txt
index 703adec8e8..e76827db3e 100644
--- a/paimon-python/dev/requirements.txt
+++ b/paimon-python/dev/requirements.txt
@@ -19,27 +19,23 @@
 cachetools>=4.2,<6; python_version=="3.6"
 cachetools>=5,<6; python_version>"3.6"
 dataclasses>=0.8; python_version < "3.7"
-fastavro>=1.4,<2; python_version<"3.9"
-fastavro>=1.4,<2; python_version>="3.9"
+fastavro>=1.4,<2
 fsspec>=2021.10,<2026; python_version<"3.8"
 fsspec>=2023,<2026; python_version>="3.8"
 ossfs>=2021.8; python_version<"3.8"
 ossfs>=2023; python_version>="3.8"
-packaging>=21,<26; python_version<"3.8"
-packaging>=21,<26; python_version>="3.8"
+packaging>=21,<26
 pandas>=1.1,<2; python_version < "3.7"
 pandas>=1.3,<3; python_version >= "3.7" and python_version < "3.9"
 pandas>=1.5,<3; python_version >= "3.9"
 polars>=0.9,<1; python_version<"3.8"
-polars>=1,<2; python_version=="3.8"
-polars>=1,<2; python_version>"3.8"
+polars>=1,<2; python_version>="3.8"
 pyarrow>=6,<7; python_version < "3.8"
-pyarrow>=16,<20; python_version >= "3.8" and python_version < "3.13"
-pyarrow>=16,<20; python_version >= "3.13"
+pyarrow>=16,<20; python_version >= "3.8"
+pylance>=0.20,<1; python_version>="3.9"
+pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
 pyroaring
 ray>=2.10,<3
 readerwriterlock>=1,<2
-zstandard>=0.19,<1; python_version<"3.9"
-zstandard>=0.19,<1; python_version>="3.9"
-pylance>=0.20,<1; python_version>="3.9"
-pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
+torch
+zstandard>=0.19,<1
\ No newline at end of file
diff --git a/paimon-python/pypaimon/read/ray_datasource.py 
b/paimon-python/pypaimon/read/datasource.py
similarity index 67%
rename from paimon-python/pypaimon/read/ray_datasource.py
rename to paimon-python/pypaimon/read/datasource.py
index 905c8bddef..835effbf0b 100644
--- a/paimon-python/pypaimon/read/ray_datasource.py
+++ b/paimon-python/pypaimon/read/datasource.py
@@ -27,6 +27,7 @@ from typing import List, Optional, Iterable
 import pyarrow
 from packaging.version import parse
 import ray
+import torch
 
 from pypaimon.read.split import Split
 from pypaimon.read.table_read import TableRead
@@ -40,8 +41,10 @@ RAY_VERSION_PER_TASK_ROW_LIMIT = "2.52.0"  # 
per_task_row_limit parameter introd
 
 from ray.data.datasource import Datasource
 
+from torch.utils.data import Dataset, IterableDataset
 
-class PaimonDatasource(Datasource):
+
+class RayDatasource(Datasource):
     """
     Ray Data Datasource implementation for reading Paimon tables.
 
@@ -76,7 +79,7 @@ class PaimonDatasource(Datasource):
 
     @staticmethod
     def _distribute_splits_into_equal_chunks(
-        splits: Iterable[Split], n_chunks: int
+            splits: Iterable[Split], n_chunks: int
     ) -> List[List[Split]]:
         """
         Implement a greedy knapsack algorithm to distribute the splits across 
tasks,
@@ -88,7 +91,7 @@ class PaimonDatasource(Datasource):
 
         # From largest to smallest, add the splits to the smallest chunk one 
at a time
         for split in sorted(
-            splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and 
s.file_size > 0 else 0, reverse=True
+                splits, key=lambda s: s.file_size if hasattr(s, 'file_size') 
and s.file_size > 0 else 0, reverse=True
         ):
             smallest_chunk = heapq.heappop(chunk_sizes)
             chunks[smallest_chunk[1]].append(split)
@@ -132,11 +135,11 @@ class PaimonDatasource(Datasource):
         # Create a partial function to avoid capturing self in closure
         # This reduces serialization overhead (see 
https://github.com/ray-project/ray/issues/49107)
         def _get_read_task(
-            splits: List[Split],
-            table=table,
-            predicate=predicate,
-            read_type=read_type,
-            schema=schema,
+                splits: List[Split],
+                table=table,
+                predicate=predicate,
+                read_type=read_type,
+                schema=schema,
         ) -> Iterable[pyarrow.Table]:
             """Read function that will be executed by Ray workers."""
             from pypaimon.read.table_read import TableRead
@@ -216,13 +219,128 @@ class PaimonDatasource(Datasource):
                 'read_fn': read_fn,
                 'metadata': metadata,
             }
-            
+
             if parse(ray.__version__) >= 
parse(RAY_VERSION_SCHEMA_IN_READ_TASK):
                 read_task_kwargs['schema'] = schema
-            
+
             if parse(ray.__version__) >= parse(RAY_VERSION_PER_TASK_ROW_LIMIT) 
and per_task_row_limit is not None:
                 read_task_kwargs['per_task_row_limit'] = per_task_row_limit
 
             read_tasks.append(ReadTask(**read_task_kwargs))
 
         return read_tasks
+
+
+class TorchDataset(Dataset):
+    """
+    Map-style PyTorch Dataset backed by a fully materialized Paimon read.
+
+    All rows of the given splits are read eagerly in ``__init__`` and kept in
+    memory as a list of row dictionaries, so ``__len__`` and ``__getitem__``
+    only touch that in-memory list.
+    """
+
+    def __init__(self, table_read: TableRead, splits: List[Split]):
+        """
+        Read every split up front and cache the resulting rows.
+
+        Args:
+            table_read: TableRead instance for reading data
+            splits: List of splits to read
+        """
+        arrow_table = table_read.to_arrow(splits)
+        if arrow_table is None or arrow_table.num_rows == 0:
+            # Nothing was read: behave as an empty dataset
+            self._data = []
+        else:
+            self._data = arrow_table.to_pylist()
+
+    def __len__(self) -> int:
+        """
+        Return the total number of rows in the dataset.
+
+        Returns:
+            Total number of rows across all splits
+        """
+        return len(self._data)
+
+    def __getitem__(self, index: int):
+        """
+        Get a single item from the dataset.
+
+        Args:
+            index: Index of the item to retrieve
+
+        Returns:
+            Dictionary containing the row data
+
+        Raises:
+            IndexError: If index is out of range, which includes every index
+                of an empty dataset.
+        """
+        # Index the list unconditionally. Returning None for an empty dataset
+        # would never raise IndexError, so ``for row in dataset`` and
+        # ``list(dataset)`` (which fall back to the __getitem__ protocol because
+        # no __iter__ is defined here) would never terminate.
+        return self._data[index]
+
+
+class TorchIterDataset(IterableDataset):
+    """
+    Streaming PyTorch IterableDataset over Paimon table data.
+
+    Rows are pulled from ``TableRead.to_iterator`` while the dataset is being
+    iterated instead of being loaded into memory in advance, which keeps
+    memory usage low for large tables.
+    """
+
+    def __init__(self, table_read: TableRead, splits: List[Split]):
+        """
+        Keep the reader and splits, and record the column names to emit.
+
+        Args:
+            table_read: TableRead instance for reading data
+            splits: List of splits to read
+        """
+        self.table_read = table_read
+        self.splits = splits
+        # Column names in read_type order; they become the keys of each yielded row
+        self.field_names = [field.name for field in table_read.read_type]
+
+    def __iter__(self):
+        """
+        Yield every row of the assigned splits as a dictionary.
+
+        When the DataLoader runs with num_workers > 0, each worker reads a
+        contiguous slice of the splits; the slices are disjoint and together
+        cover all splits. Without worker info, all splits are read.
+
+        Yields:
+            dict mapping column name to the value of that field in the row
+        """
+        info = torch.utils.data.get_worker_info()
+
+        if info is None:
+            # Single-process data loading: this iterator owns every split
+            assigned = self.splits
+        else:
+            # Multi-process data loading: split evenly, and let the first
+            # `extra` workers take one additional split each
+            base, extra = divmod(len(self.splits), info.num_workers)
+            begin = info.id * base + min(info.id, extra)
+            end = begin + base + (1 if info.id < extra else 0)
+            assigned = self.splits[begin:end]
+
+        for offset_row in self.table_read.to_iterator(assigned):
+            yield {
+                name: offset_row.get_field(position)
+                for position, name in enumerate(self.field_names)
+            }
diff --git a/paimon-python/pypaimon/read/table_read.py 
b/paimon-python/pypaimon/read/table_read.py
index 953384cc7d..7e8dbda412 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -165,8 +165,8 @@ class TableRead:
         if override_num_blocks is not None and override_num_blocks < 1:
             raise ValueError(f"override_num_blocks must be at least 1, got 
{override_num_blocks}")
 
-        from pypaimon.read.ray_datasource import PaimonDatasource
-        datasource = PaimonDatasource(self, splits)
+        from pypaimon.read.datasource import RayDatasource
+        datasource = RayDatasource(self, splits)
         return ray.data.read_datasource(
             datasource,
             ray_remote_args=ray_remote_args,
@@ -175,6 +175,17 @@ class TableRead:
             **read_args
         )
 
+    def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset":
+        """
+        Wrap Paimon table data to PyTorch Dataset.
+
+        Args:
+            splits: Splits to read.
+            streaming: If True, return a TorchIterDataset; otherwise return a
+                TorchDataset.
+
+        Returns:
+            A TorchIterDataset or TorchDataset built from this reader and the splits.
+        """
+        # Imported inside the method rather than at module top, like the
+        # datasource import in to_ray.
+        from pypaimon.read.datasource import TorchDataset, TorchIterDataset
+        dataset_cls = TorchIterDataset if streaming else TorchDataset
+        return dataset_cls(self, splits)
+
     def _create_split_read(self, split: Split) -> SplitRead:
         if self.table.is_primary_key_table and not split.raw_convertible:
             return MergeFileSplitRead(
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py 
b/paimon-python/pypaimon/tests/blob_table_test.py
index f87f73ded7..9925e21be5 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2644,7 +2644,7 @@ class DataBlobWriterTest(unittest.TestCase):
 
             # Create and start multiple threads
             threads = []
-            num_threads = 100
+            num_threads = 10
             for i in range(num_threads):
                 thread = threading.Thread(
                     target=write_blob_data,
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py 
b/paimon-python/pypaimon/tests/reader_append_only_test.py
index d65658ef5c..adb0ff4f25 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -438,44 +438,6 @@ class AoReaderTest(unittest.TestCase):
         }, schema=self.pa_schema).sort_by('user_id')
         self.assertEqual(expected, actual)
 
-    def _write_test_table(self, table):
-        write_builder = table.new_batch_write_builder()
-
-        # first write
-        table_write = write_builder.new_write()
-        table_commit = write_builder.new_commit()
-        data1 = {
-            'user_id': [1, 2, 3, 4],
-            'item_id': [1001, 1002, 1003, 1004],
-            'behavior': ['a', 'b', 'c', None],
-            'dt': ['p1', 'p1', 'p2', 'p1'],
-        }
-        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
-        table_write.write_arrow(pa_table)
-        table_commit.commit(table_write.prepare_commit())
-        table_write.close()
-        table_commit.close()
-
-        # second write
-        table_write = write_builder.new_write()
-        table_commit = write_builder.new_commit()
-        data2 = {
-            'user_id': [5, 6, 7, 8],
-            'item_id': [1005, 1006, 1007, 1008],
-            'behavior': ['e', 'f', 'g', 'h'],
-            'dt': ['p2', 'p1', 'p2', 'p2'],
-        }
-        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
-        table_write.write_arrow(pa_table)
-        table_commit.commit(table_write.prepare_commit())
-        table_write.close()
-        table_commit.close()
-
-    def _read_test_table(self, read_builder):
-        table_read = read_builder.new_read()
-        splits = read_builder.new_scan().plan().splits()
-        return table_read.to_arrow(splits)
-
     def test_concurrent_writes_with_retry(self):
         """Test concurrent writes to verify retry mechanism works correctly."""
         import threading
@@ -529,7 +491,7 @@ class AoReaderTest(unittest.TestCase):
 
             # Create and start multiple threads
             threads = []
-            num_threads = 100
+            num_threads = 10
             for i in range(num_threads):
                 thread = threading.Thread(
                     target=write_data,
@@ -576,3 +538,41 @@ class AoReaderTest(unittest.TestCase):
                              f"got {latest_snapshot.id}")
 
             print(f"✓ Iteration {test_iteration + 1}/{iter_num} completed 
successfully")
+
+    def _write_test_table(self, table):
+        """Write two batches of four rows to ``table``, committing each batch separately."""
+        builder = table.new_batch_write_builder()
+        batches = [
+            # first write
+            {
+                'user_id': [1, 2, 3, 4],
+                'item_id': [1001, 1002, 1003, 1004],
+                'behavior': ['a', 'b', 'c', None],
+                'dt': ['p1', 'p1', 'p2', 'p1'],
+            },
+            # second write
+            {
+                'user_id': [5, 6, 7, 8],
+                'item_id': [1005, 1006, 1007, 1008],
+                'behavior': ['e', 'f', 'g', 'h'],
+                'dt': ['p2', 'p1', 'p2', 'p2'],
+            },
+        ]
+        for batch in batches:
+            # A fresh writer/committer pair per batch, so each batch is its own commit
+            writer = builder.new_write()
+            committer = builder.new_commit()
+            arrow_batch = pa.Table.from_pydict(batch, schema=self.pa_schema)
+            writer.write_arrow(arrow_batch)
+            committer.commit(writer.prepare_commit())
+            writer.close()
+            committer.close()
+
+    def _read_test_table(self, read_builder):
+        """Plan a scan with ``read_builder`` and return the planned splits read via ``to_arrow``."""
+        reader = read_builder.new_read()
+        planned_splits = read_builder.new_scan().plan().splits()
+        return reader.to_arrow(planned_splits)
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py 
b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index 731203385d..c22346afe7 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -479,7 +479,7 @@ class PkReaderTest(unittest.TestCase):
 
             # Create and start multiple threads
             threads = []
-            num_threads = 100
+            num_threads = 10
             for i in range(num_threads):
                 thread = threading.Thread(
                     target=write_data,
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py 
b/paimon-python/pypaimon/tests/torch_read_test.py
new file mode 100644
index 0000000000..b6862c6cb1
--- /dev/null
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -0,0 +1,635 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+from parameterized import parameterized
+from torch.utils.data import DataLoader
+
+from pypaimon import CatalogFactory, Schema
+
+from pypaimon.table.file_store_table import FileStoreTable
+
+
+class TorchReadTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Create a scratch warehouse, a catalog on top of it, and the shared schema/reference rows."""
+        # All tables created by the tests live below this temporary directory.
+        scratch_dir = tempfile.mkdtemp()
+        warehouse_path = os.path.join(scratch_dir, 'warehouse')
+        cls.tempdir = scratch_dir
+        cls.warehouse = warehouse_path
+        cls.catalog = CatalogFactory.create({'warehouse': warehouse_path})
+        cls.catalog.create_database('default', True)
+
+        # Column layout used by every non-blob table in this class.
+        columns = [
+            ('user_id', pa.int32()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string())
+        ]
+        cls.pa_schema = pa.schema(columns)
+
+        reference_rows = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
+        }
+        cls.expected = pa.Table.from_pydict(reference_rows, schema=cls.pa_schema)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Delete the directory tree rooted at ``cls.tempdir``, ignoring any removal errors."""
+        scratch_dir = cls.tempdir
+        shutil.rmtree(scratch_dir, ignore_errors=True)
+
+    @parameterized.expand([True, False])
+    def test_torch_read(self, is_streaming: bool = False):
+        """Read a projected table through a DataLoader, once per ``streaming`` setting."""
+        table_name = f'default.test_torch_read_{str(is_streaming)}'
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+        self.catalog.create_table(table_name, schema, False)
+        table = self.catalog.get_table(table_name)
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        dataset = table_read.to_torch(table_scan.plan().splits(), streaming=is_streaming)
+        loader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+
+        # Gather (user_id, behavior) pairs batch by batch.
+        rows = []
+        for batch in loader:
+            rows.extend(zip(batch['user_id'].tolist(), batch['behavior']))
+
+        # Order by user_id so the comparison does not depend on batch arrival order.
+        rows.sort(key=lambda pair: pair[0])
+        sorted_user_ids = [pair[0] for pair in rows]
+        sorted_behaviors = [pair[1] for pair in rows]
+
+        # Expected data (sorted by user_id)
+        expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+        expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+
+        self.assertEqual(sorted_user_ids, expected_user_ids,
+                         f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}")
+        self.assertEqual(sorted_behaviors, expected_behaviors,
+                         f"Behaviors mismatch. Expected {expected_behaviors}, got {sorted_behaviors}")
+
+        print(f"✓ Test passed: Successfully read {len(rows)} rows with correct data")
+
+    def test_blob_torch_read(self):
+        """Round-trip a blob descriptor through ``to_torch`` and verify the bytes it points to.
+
+        Writes one row whose ``picture`` column holds a serialized ``BlobDescriptor`` referring to
+        an external 1 MiB file, reads the row back through a ``DataLoader``, then resolves the
+        descriptor that was read and compares the resulting bytes with the original payload.
+        """
+        import random
+        from pypaimon.common.options.config import CatalogOptions
+        from pypaimon.common.uri_reader import UriReaderFactory
+        from pypaimon.table.row.blob import Blob, BlobDescriptor
+
+        # Create schema with blob column
+        pa_schema = pa.schema([
+            ('id', pa.int32()),
+            ('picture', pa.large_binary()),
+        ])
+
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'row-tracking.enabled': 'true',
+                'data-evolution.enabled': 'true',
+                'blob-as-descriptor': 'true'
+            }
+        )
+
+        # Create table
+        self.catalog.create_table('default.test_blob_torch_read', schema, False)
+        table: FileStoreTable = self.catalog.get_table('default.test_blob_torch_read')
+
+        # Create test blob data (1MB). A private, seeded generator keeps the payload
+        # reproducible without reseeding the process-wide ``random`` state.
+        rng = random.Random(42)
+        blob_data = bytes(rng.randint(0, 255) for _ in range(1024 * 1024))
+
+        # Create external blob file
+        external_blob_path = os.path.join(self.tempdir, 'external_blob')
+        with open(external_blob_path, 'wb') as f:
+            f.write(blob_data)
+
+        # Create blob descriptor pointing to external file
+        blob_descriptor = BlobDescriptor(external_blob_path, 0, len(blob_data))
+
+        # Create test data with blob descriptor
+        test_data = pa.Table.from_pydict({
+            'id': [1],
+            'picture': [blob_descriptor.serialize()]
+        }, schema=pa_schema)
+
+        # Write and commit the data, then close both handles like the other
+        # write paths in this class do.
+        write_builder = table.new_batch_write_builder()
+        writer = write_builder.new_write()
+        commit = write_builder.new_commit()
+        writer.write_arrow(test_data)
+        commit.commit(writer.prepare_commit())
+        writer.close()
+        commit.close()
+
+        # Read data back
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        result = table_read.to_torch(table_scan.plan().splits())
+
+        dataloader = DataLoader(
+            result,
+            batch_size=1,
+            num_workers=0,
+            shuffle=False
+        )
+
+        # Collect and verify data
+        all_ids = []
+        all_pictures = []
+        for batch_data in dataloader:
+            all_ids.extend(batch_data['id'].tolist())
+            all_pictures.extend(batch_data['picture'])
+
+        # Verify results
+        self.assertEqual(len(all_ids), 1, "Should have exactly 1 row")
+        self.assertEqual(all_ids[0], 1, "ID should be 1")
+
+        # Verify blob descriptor
+        picture_bytes = all_pictures[0]
+        self.assertIsInstance(picture_bytes, bytes, "Picture should be bytes")
+
+        # Deserialize and verify blob descriptor
+        read_blob_descriptor = BlobDescriptor.deserialize(picture_bytes)
+        self.assertEqual(read_blob_descriptor.length, len(blob_data),
+                         f"Blob length mismatch. Expected {len(blob_data)}, got {read_blob_descriptor.length}")
+        self.assertGreaterEqual(read_blob_descriptor.offset, 0, "Offset should be non-negative")
+
+        # Read and verify blob content
+        catalog_options = {CatalogOptions.WAREHOUSE.key(): self.warehouse}
+        uri_reader_factory = UriReaderFactory(catalog_options)
+        uri_reader = uri_reader_factory.create(read_blob_descriptor.uri)
+        blob = Blob.from_descriptor(uri_reader, read_blob_descriptor)
+
+        # Verify blob data matches original
+        read_blob_data = blob.to_data()
+        self.assertEqual(len(read_blob_data), len(blob_data),
+                         f"Blob data length mismatch. Expected {len(blob_data)}, got {len(read_blob_data)}")
+        self.assertEqual(read_blob_data, blob_data, "Blob data content should match original")
+
+        print(f"✓ Blob torch read test passed: Successfully read and verified {len(blob_data)} bytes of blob data")
+
+    def test_torch_read_pk_table(self):
+        """Read a bucketed primary-key table through a streaming torch dataset.
+
+        The table uses (user_id, behavior) as its primary key, is partitioned by ``behavior``
+        and has two buckets. All eight fixture rows must come back through a three-worker
+        ``DataLoader``.
+        """
+        # Create PK table keyed by (user_id, behavior) and partitioned by behavior.
+        # The bucket option is passed as a string, matching the other PK test in this class.
+        schema = Schema.from_pyarrow_schema(
+            self.pa_schema,
+            primary_keys=['user_id', 'behavior'],
+            partition_keys=['behavior'],
+            options={'bucket': '2'}
+        )
+        self.catalog.create_table('default.test_pk_table', schema, False)
+        table = self.catalog.get_table('default.test_pk_table')
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        dataset = table_read.to_torch(splits, streaming=True)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=3,
+            shuffle=False
+        )
+
+        # Collect all data from dataloader
+        all_user_ids = []
+        all_behaviors = []
+        for batch_data in dataloader:
+            all_user_ids.extend(batch_data['user_id'].tolist())
+            all_behaviors.extend(batch_data['behavior'])
+
+        # Sort by user_id for comparison
+        sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0])
+        sorted_user_ids = [x[0] for x in sorted_data]
+        sorted_behaviors = [x[1] for x in sorted_data]
+
+        # Expected data (sorted by user_id)
+        expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+        expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+
+        # Verify results
+        self.assertEqual(sorted_user_ids, expected_user_ids,
+                         f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}")
+        self.assertEqual(sorted_behaviors, expected_behaviors,
+                         f"Behaviors mismatch. Expected {expected_behaviors}, got {sorted_behaviors}")
+
+        print(f"✓ PK table test passed: Successfully read {len(all_user_ids)} rows with correct data")
+
+    def test_torch_read_large_append_table(self):
+        """Test torch read with large data volume on append-only table.
+
+        Writes 100000 rows in ten commits to a table partitioned by ``dt``, reads them back
+        through a four-worker streaming ``DataLoader`` and checks that every ``user_id`` from
+        1 to 100000 is returned exactly once.
+        """
+        # Create append-only table
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_large_append', schema, False)
+        table = self.catalog.get_table('default.test_large_append')
+
+        # Write large amount of data
+        write_builder = table.new_batch_write_builder()
+        total_rows = 100000
+        batch_size = 10000
+        num_batches = total_rows // batch_size
+
+        print(f"\n{'=' * 60}")
+        print(f"Writing {total_rows} rows to append-only table...")
+        print(f"{'=' * 60}")
+
+        for batch_idx in range(num_batches):
+            table_write = write_builder.new_write()
+            table_commit = write_builder.new_commit()
+
+            # user_id values are 1-based and contiguous across batches.
+            start_id = batch_idx * batch_size + 1
+            end_id = start_id + batch_size
+
+            data = {
+                'user_id': list(range(start_id, end_id)),
+                'item_id': [1000 + i for i in range(start_id, end_id)],
+                'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)],
+                'dt': [f'p{i % 4}' for i in range(batch_size)],
+            }
+            pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+            table_write.write_arrow(pa_table)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+            if (batch_idx + 1) % 2 == 0:
+                print(f"  Written {(batch_idx + 1) * batch_size} rows...")
+
+        # Read data using torch
+        print(f"\nReading {total_rows} rows using Torch DataLoader...")
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+
+        print(f"Total splits: {len(splits)}")
+
+        dataset = table_read.to_torch(splits, streaming=True)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1000,
+            num_workers=4,
+            shuffle=False
+        )
+
+        # Collect all data
+        all_user_ids = []
+        batch_count = 0
+        for batch_idx, batch_data in enumerate(dataloader):
+            batch_count += 1
+            user_ids = batch_data['user_id'].tolist()
+            all_user_ids.extend(user_ids)
+
+            if (batch_idx + 1) % 20 == 0:
+                print(f"  Read {len(all_user_ids)} rows...")
+
+        all_user_ids.sort()
+        # Verify data: first the row count, then the actual id values.
+        self.assertEqual(len(all_user_ids), total_rows,
+                         f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+        self.assertEqual(all_user_ids, list(range(1, total_rows + 1)),
+                         f"User IDs mismatch. Expected every id in 1..{total_rows} exactly once")
+        print(f"\n{'=' * 60}")
+        print("✓ Large append table test passed!")
+        print(f"  Total rows: {total_rows}")
+        print(f"  Total batches: {batch_count}")
+        print(f"{'=' * 60}\n")
+
+    def test_torch_read_large_pk_table(self):
+        """Test torch read with large data volume on primary key table.
+
+        Writes 100000 rows in ten commits to a four-bucket table keyed by ``user_id`` and
+        partitioned by ``dt``, reads them back through an eight-worker streaming ``DataLoader``
+        and checks that every ``user_id`` from 1 to 100000 is returned exactly once.
+        """
+
+        # Create PK table
+        schema = Schema.from_pyarrow_schema(
+            self.pa_schema,
+            primary_keys=['user_id'],
+            partition_keys=['dt'],
+            options={'bucket': '4'}
+        )
+        self.catalog.create_table('default.test_large_pk', schema, False)
+        table = self.catalog.get_table('default.test_large_pk')
+
+        # Write large amount of data
+        write_builder = table.new_batch_write_builder()
+        total_rows = 100000
+        batch_size = 10000
+        num_batches = total_rows // batch_size
+
+        print(f"\n{'=' * 60}")
+        print(f"Writing {total_rows} rows to PK table...")
+        print(f"{'=' * 60}")
+
+        for batch_idx in range(num_batches):
+            table_write = write_builder.new_write()
+            table_commit = write_builder.new_commit()
+
+            # user_id values are 1-based and contiguous across batches.
+            start_id = batch_idx * batch_size + 1
+            end_id = start_id + batch_size
+
+            data = {
+                'user_id': list(range(start_id, end_id)),
+                'item_id': [1000 + i for i in range(start_id, end_id)],
+                'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)],
+                'dt': [f'p{i % 4}' for i in range(batch_size)],
+            }
+            pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+            table_write.write_arrow(pa_table)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+            if (batch_idx + 1) % 2 == 0:
+                print(f"  Written {(batch_idx + 1) * batch_size} rows...")
+
+        # Read data using torch
+        print(f"\nReading {total_rows} rows using Torch DataLoader...")
+
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+
+        print(f"Total splits: {len(splits)}")
+
+        dataset = table_read.to_torch(splits, streaming=True)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1000,
+            num_workers=8,
+            shuffle=False
+        )
+
+        # Collect all data
+        all_user_ids = []
+        batch_count = 0
+        for batch_idx, batch_data in enumerate(dataloader):
+            batch_count += 1
+            user_ids = batch_data['user_id'].tolist()
+            all_user_ids.extend(user_ids)
+
+            if (batch_idx + 1) % 20 == 0:
+                print(f"  Read {len(all_user_ids)} rows...")
+
+        all_user_ids.sort()
+        # Verify data: first the row count, then the actual id values.
+        self.assertEqual(len(all_user_ids), total_rows,
+                         f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}")
+
+        # Equality with the full range also implies no user_id was returned twice.
+        self.assertEqual(all_user_ids, list(range(1, total_rows + 1)),
+                         f"User IDs mismatch. Expected every id in 1..{total_rows} exactly once")
+
+        print(f"\n{'=' * 60}")
+        print("✓ Large PK table test passed!")
+        print(f"  Total rows: {total_rows}")
+        print(f"  Total batches: {batch_count}")
+        print("  Primary key uniqueness: ✓")
+        print(f"{'=' * 60}\n")
+
+    def test_torch_read_with_predicate(self):
+        """Test torch read with predicate filtering.
+
+        Five filters are applied to the eight fixture rows: ``>``, ``<=``, ``=``, ``IN`` and an
+        ``AND`` of two comparisons. Each filter runs in its own ``subTest`` so that one failing
+        case does not hide the results of the others.
+        """
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+        self.catalog.create_table('default.test_predicate', schema, False)
+        table = self.catalog.get_table('default.test_predicate')
+        self._write_test_table(table)
+
+        predicate_builder = table.new_read_builder().new_predicate_builder()
+
+        def announce(title):
+            # Banner printed ahead of every case.
+            print(f"\n{'=' * 60}")
+            print(title)
+            print(f"{'=' * 60}")
+
+        def read_filtered(predicate, with_behaviors=False):
+            # Build a fresh filtered read and drain it through a single-process DataLoader.
+            # Returns (user_ids, behaviors); behaviors stays empty unless requested.
+            read_builder = table.new_read_builder().with_filter(predicate)
+            table_scan = read_builder.new_scan()
+            table_read = read_builder.new_read()
+            splits = table_scan.plan().splits()
+            dataset = table_read.to_torch(splits, streaming=True)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                num_workers=0,
+                shuffle=False
+            )
+            user_ids = []
+            behaviors = []
+            for batch_data in dataloader:
+                user_ids.extend(batch_data['user_id'].tolist())
+                if with_behaviors:
+                    behaviors.extend(batch_data['behavior'])
+            return user_ids, behaviors
+
+        def check_user_ids(title, predicate, expected_user_ids):
+            # Shared body of the cases that only compare the sorted user_id values.
+            with self.subTest(case=title):
+                announce(title)
+                all_user_ids, _ = read_filtered(predicate)
+                all_user_ids.sort()
+                self.assertEqual(all_user_ids, expected_user_ids,
+                                 f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+                print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}")
+
+        # Test case 1: Filter by user_id > 4
+        check_user_ids("Test Case 1: user_id > 4",
+                       predicate_builder.greater_than('user_id', 4),
+                       [5, 6, 7, 8])
+
+        # Test case 2: Filter by user_id <= 3
+        check_user_ids("Test Case 2: user_id <= 3",
+                       predicate_builder.less_or_equal('user_id', 3),
+                       [1, 2, 3])
+
+        # Test case 3: Filter by behavior = 'a' (also checks the behavior column)
+        with self.subTest(case="Test Case 3: behavior = 'a'"):
+            announce("Test Case 3: behavior = 'a'")
+            predicate = predicate_builder.equal('behavior', 'a')
+            all_user_ids, all_behaviors = read_filtered(predicate, with_behaviors=True)
+
+            expected_user_ids = [1]
+            expected_behaviors = ['a']
+            self.assertEqual(all_user_ids, expected_user_ids,
+                             f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}")
+            self.assertEqual(all_behaviors, expected_behaviors,
+                             f"Behaviors mismatch. Expected {expected_behaviors}, got {all_behaviors}")
+            print(f"✓ Filtered {len(all_user_ids)} rows: user_ids={all_user_ids}, behaviors={all_behaviors}")
+
+        # Test case 4: Filter by user_id IN (2, 4, 6)
+        check_user_ids("Test Case 4: user_id IN (2, 4, 6)",
+                       predicate_builder.is_in('user_id', [2, 4, 6]),
+                       [2, 4, 6])
+
+        # Test case 5: Combined filter (user_id > 2 AND user_id < 7)
+        predicate1 = predicate_builder.greater_than('user_id', 2)
+        predicate2 = predicate_builder.less_than('user_id', 7)
+        combined_predicate = predicate_builder.and_predicates([predicate1, predicate2])
+        check_user_ids("Test Case 5: user_id > 2 AND user_id < 7",
+                       combined_predicate,
+                       [3, 4, 5, 6])
+
+        print(f"\n{'=' * 60}")
+        print("✓ All predicate test cases passed!")
+        print(f"{'=' * 60}\n")
+
+    def _write_test_table(self, table):
+        """Append the eight fixture rows to ``table`` as two separate four-row commits."""
+        write_builder = table.new_batch_write_builder()
+
+        batches = (
+            # first write
+            {
+                'user_id': [1, 2, 3, 4],
+                'item_id': [1001, 1002, 1003, 1004],
+                'behavior': ['a', 'b', 'c', 'd'],
+                'dt': ['p1', 'p1', 'p2', 'p1'],
+            },
+            # second write
+            {
+                'user_id': [5, 6, 7, 8],
+                'item_id': [1005, 1006, 1007, 1008],
+                'behavior': ['e', 'f', 'g', 'h'],
+                'dt': ['p2', 'p1', 'p2', 'p1'],
+            },
+        )
+
+        # Each batch gets its own write/commit pair, closed once the commit is done.
+        for rows in batches:
+            table_write = write_builder.new_write()
+            table_commit = write_builder.new_commit()
+            table_write.write_arrow(pa.Table.from_pydict(rows, schema=self.pa_schema))
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+    def _read_test_table(self, read_builder):
+        """Scan everything ``read_builder`` selects and return what ``to_arrow`` yields for it."""
+        reader = read_builder.new_read()
+        plan = read_builder.new_scan().plan()
+        return reader.to_arrow(plan.splits())

Reply via email to