This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 78565733c1 [python] Fix default file compression and support file compression when writing avro file  (#6996)
78565733c1 is described below

commit 78565733c10ad8d9ad9249cb6a85718e02a70e7f
Author: XiaoHongbo <[email protected]>
AuthorDate: Sat Jan 10 19:47:54 2026 +0800

    [python] Fix default file compression and support file compression when writing avro file  (#6996)
---
 .github/workflows/paimon-python-checks.yml         |   3 +-
 paimon-python/dev/requirements.txt                 |   3 +-
 paimon-python/pypaimon/common/file_io.py           |  43 ++++-
 .../pypaimon/common/options/core_options.py        |  17 +-
 paimon-python/pypaimon/tests/reader_base_test.py   | 189 +++++++++++++++++++--
 .../pypaimon/write/writer/data_blob_writer.py      |   6 +-
 paimon-python/pypaimon/write/writer/data_writer.py |   7 +-
 7 files changed, 237 insertions(+), 31 deletions(-)

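For context on how these options are meant to be used: this change switches the default 'file.compression' from lz4 to zstd, adds 'file.compression.zstd-level' (default 1, honored by the Parquet and Avro writers), and makes the Avro writer respect 'file.compression' at all (previously no codec was passed to fastavro, i.e. the null codec). A minimal usage sketch, modeled on the new tests further below; the warehouse path, database and table names are illustrative only, and option values are passed as strings as in the tests:

    import pyarrow as pa
    from pypaimon import CatalogFactory, Schema

    catalog = CatalogFactory.create({"warehouse": "/tmp/paimon_warehouse"})
    catalog.create_database("demo_db", False)

    pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    schema = Schema.from_pyarrow_schema(
        pa_schema,
        options={
            'file.format': 'avro',                # or 'parquet' / 'orc'
            'file.compression': 'zstd',           # now the default
            'file.compression.zstd-level': '1',   # new option, default level is 1
        }
    )
    catalog.create_table("demo_db.demo_table", schema, False)
    table = catalog.get_table("demo_db.demo_table")

    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()
    table_write.write_arrow(pa.Table.from_pydict(
        {'f0': [1, 2, 3], 'f1': ['a', 'b', 'c']}, schema=pa_schema))
    table_commit.commit(table_write.prepare_commit())
    table_write.close()
    table_commit.close()
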
diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml
index ff2929c0fa..367668453d 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -97,7 +97,7 @@ jobs:
           else
             python -m pip install --upgrade pip
             pip install torch --index-url https://download.pytorch.org/whl/cpu
-            python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
+            python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
           fi
           df -h
       - name: Run lint-python.sh
@@ -177,6 +177,7 @@ jobs:
             duckdb==1.3.2 \
             numpy==1.24.3 \
             pandas==2.0.3 \
+            cramjam \
             pytest~=7.0 \
             py4j==0.10.9.9 \
             requests \
diff --git a/paimon-python/dev/requirements.txt b/paimon-python/dev/requirements.txt
index e76827db3e..9f07798037 100644
--- a/paimon-python/dev/requirements.txt
+++ b/paimon-python/dev/requirements.txt
@@ -38,4 +38,5 @@ pyroaring
 ray>=2.10,<3
 readerwriterlock>=1,<2
 torch
-zstandard>=0.19,<1
\ No newline at end of file
+zstandard>=0.19,<1
+cramjam>=0.6,<1; python_version>="3.7"
diff --git a/paimon-python/pypaimon/common/file_io.py b/paimon-python/pypaimon/common/file_io.py
index 556d9e9ae7..80430aef9c 100644
--- a/paimon-python/pypaimon/common/file_io.py
+++ b/paimon-python/pypaimon/common/file_io.py
@@ -369,20 +369,30 @@ class FileIO:
 
         return None
 
-    def write_parquet(self, path: str, data: pyarrow.Table, compression: str = 'snappy', **kwargs):
+    def write_parquet(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
+                      zstd_level: int = 1, **kwargs):
         try:
             import pyarrow.parquet as pq
 
             with self.new_output_stream(path) as output_stream:
+                if compression.lower() == 'zstd':
+                    kwargs['compression_level'] = zstd_level
                 pq.write_table(data, output_stream, compression=compression, **kwargs)
 
         except Exception as e:
             self.delete_quietly(path)
             raise RuntimeError(f"Failed to write Parquet file {path}: {e}") from e
 
-    def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd', **kwargs):
+    def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
+                  zstd_level: int = 1, **kwargs):
         try:
-            """Write ORC file using PyArrow ORC writer."""
+            """Write ORC file using PyArrow ORC writer.
+            
+            Note: PyArrow's ORC writer doesn't support compression_level parameter.
+            ORC files will use zstd compression with default level
+            (which is 3, see https://github.com/facebook/zstd/blob/dev/programs/zstdcli.c)
+            instead of the specified level.
+            """
             import sys
             import pyarrow.orc as orc
 
@@ -402,7 +412,10 @@ class FileIO:
             self.delete_quietly(path)
             raise RuntimeError(f"Failed to write ORC file {path}: {e}") from e
 
-    def write_avro(self, path: str, data: pyarrow.Table, avro_schema: Optional[Dict[str, Any]] = None, **kwargs):
+    def write_avro(
+            self, path: str, data: pyarrow.Table,
+            avro_schema: Optional[Dict[str, Any]] = None,
+            compression: str = 'zstd', zstd_level: int = 1, **kwargs):
         import fastavro
         if avro_schema is None:
             from pypaimon.schema.data_types import PyarrowFieldParser
@@ -417,8 +430,28 @@ class FileIO:
 
         records = record_generator()
 
+        codec_map = {
+            'null': 'null',
+            'deflate': 'deflate',
+            'snappy': 'snappy',
+            'bzip2': 'bzip2',
+            'xz': 'xz',
+            'zstandard': 'zstandard',
+            'zstd': 'zstandard',  # zstd is commonly used in Paimon
+        }
+        compression_lower = compression.lower()
+        
+        codec = codec_map.get(compression_lower)
+        if codec is None:
+            raise ValueError(
+                f"Unsupported compression '{compression}' for Avro format. "
+                f"Supported compressions: {', '.join(sorted(codec_map.keys()))}."
+            )
+
         with self.new_output_stream(path) as output_stream:
-            fastavro.writer(output_stream, avro_schema, records, **kwargs)
+            if codec == 'zstandard':
+                kwargs['codec_compression_level'] = zstd_level
+            fastavro.writer(output_stream, avro_schema, records, codec=codec, **kwargs)
 
     def write_lance(self, path: str, data: pyarrow.Table, **kwargs):
         try:
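
As a rough illustration of what the updated write_parquet and write_avro delegate to underneath (a standalone sketch, not part of the patch): pyarrow's Parquet writer takes the level via compression_level, fastavro takes a codec name plus codec_compression_level, and, as the docstring above notes, pyarrow's ORC writer exposes no level at all, so ORC falls back to the zstd library default.

    import io
    import fastavro
    import pyarrow as pa
    import pyarrow.parquet as pq

    data = pa.table({'f0': [1, 2, 3]})

    # Parquet: zstd with an explicit level, as write_parquet now does.
    parquet_buf = io.BytesIO()
    pq.write_table(data, parquet_buf, compression='zstd', compression_level=1)

    # Avro: Paimon's 'zstd' maps to fastavro's 'zstandard' codec and the level is
    # forwarded through codec_compression_level. The zstandard codec needs the
    # zstandard package; 'snappy' needs cramjam, hence the requirements change.
    avro_schema = {'type': 'record', 'name': 'rec',
                   'fields': [{'name': 'f0', 'type': 'int'}]}
    avro_buf = io.BytesIO()
    fastavro.writer(avro_buf, avro_schema, [{'f0': 1}, {'f0': 2}, {'f0': 3}],
                    codec='zstandard', codec_compression_level=1)
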
diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 49230240a7..ddc4c03e39 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -125,8 +125,18 @@ class CoreOptions:
     FILE_COMPRESSION: ConfigOption[str] = (
         ConfigOptions.key("file.compression")
         .string_type()
-        .default_value("lz4")
-        .with_description("Default file compression format.")
+        .default_value("zstd")
+        .with_description("Default file compression format. For faster read and write, it is recommended to use zstd.")
+    )
+
+    FILE_COMPRESSION_ZSTD_LEVEL: ConfigOption[int] = (
+        ConfigOptions.key("file.compression.zstd-level")
+        .int_type()
+        .default_value(1)
+        .with_description(
+            "Default file compression zstd level. For higher compression rates, it can be configured to 9, "
+            "but the read and write speed will significantly decrease."
+        )
     )
 
     FILE_COMPRESSION_PER_LEVEL: ConfigOption[Dict[str, str]] = (
@@ -346,6 +356,9 @@ class CoreOptions:
     def file_compression(self, default=None):
         return self.options.get(CoreOptions.FILE_COMPRESSION, default)
 
+    def file_compression_zstd_level(self, default=None):
+        return self.options.get(CoreOptions.FILE_COMPRESSION_ZSTD_LEVEL, default)
+
     def file_compression_per_level(self, default=None):
         return self.options.get(CoreOptions.FILE_COMPRESSION_PER_LEVEL, default)
 
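The trade-off described above (higher level, smaller files, slower reads and writes) can be sanity-checked with pyarrow alone; an illustrative snippet, not part of the patch:

    import io
    import pyarrow as pa
    import pyarrow.parquet as pq

    data = pa.table({'f0': list(range(100_000)) * 2})
    sizes = {}
    for level in (1, 9):
        buf = io.BytesIO()
        pq.write_table(data, buf, compression='zstd', compression_level=level)
        sizes[level] = len(buf.getvalue())
    # Level 9 usually yields a smaller file than level 1, at the cost of write speed.
    print(sizes)
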
diff --git a/paimon-python/pypaimon/tests/reader_base_test.py b/paimon-python/pypaimon/tests/reader_base_test.py
index c6223dadcb..a7b2abd516 100644
--- a/paimon-python/pypaimon/tests/reader_base_test.py
+++ b/paimon-python/pypaimon/tests/reader_base_test.py
@@ -28,6 +28,7 @@ from unittest.mock import Mock
 
 import pandas as pd
 import pyarrow as pa
+from parameterized import parameterized
 
 from pypaimon import CatalogFactory, Schema
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager
@@ -675,7 +676,12 @@ class ReaderBasicTest(unittest.TestCase):
             l2.append(field.to_dict())
         self.assertEqual(l1, l2)
 
-    def test_write(self):
+    @parameterized.expand([
+        ('parquet',),
+        ('orc',),
+        ('avro',),
+    ])
+    def test_write(self, file_format):
         pa_schema = pa.schema([
             ('f0', pa.int32()),
             ('f1', pa.string()),
@@ -684,9 +690,15 @@ class ReaderBasicTest(unittest.TestCase):
         catalog = CatalogFactory.create({
             "warehouse": self.warehouse
         })
-        catalog.create_database("test_write_db", False)
-        catalog.create_table("test_write_db.test_table", Schema.from_pyarrow_schema(pa_schema), False)
-        table = catalog.get_table("test_write_db.test_table")
+        db_name = f"test_write_{file_format}_db"
+        table_name = f"test_{file_format}_table"
+        catalog.create_database(db_name, False)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={'file.format': file_format}
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
 
         data = {
             'f0': [1, 2, 3],
@@ -704,17 +716,7 @@ class ReaderBasicTest(unittest.TestCase):
         table_write.close()
         table_commit.close()
 
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/LATEST"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/snapshot-1"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/manifest"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/bucket-0"))
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/manifest/*")), 3)
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/bucket-0/*.parquet")), 1)
-
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-1', 'r', encoding='utf-8') as file:
-            content = ''.join(file.readlines())
-            self.assertTrue(content.__contains__('\"totalRecordCount\": 3'))
-            self.assertTrue(content.__contains__('\"deltaRecordCount\": 3'))
+        self._verify_file_compression(file_format, db_name, table_name, expected_rows=3)
 
         write_builder = table.new_batch_write_builder()
         table_write = write_builder.new_write()
@@ -725,11 +727,166 @@ class ReaderBasicTest(unittest.TestCase):
         table_write.close()
         table_commit.close()
 
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-2', 'r', encoding='utf-8') as file:
+        snapshot_path = os.path.join(self.warehouse, f"{db_name}.db", table_name, "snapshot", "snapshot-2")
+        with open(snapshot_path, 'r', encoding='utf-8') as file:
             content = ''.join(file.readlines())
             self.assertTrue(content.__contains__('\"totalRecordCount\": 6'))
             self.assertTrue(content.__contains__('\"deltaRecordCount\": 3'))
 
+    @parameterized.expand([
+        ('parquet', 'zstd'),
+        ('parquet', 'lz4'),
+        ('parquet', 'snappy'),
+        ('orc', 'zstd'),
+        ('orc', 'lz4'),
+        ('orc', 'snappy'),
+        ('avro', 'zstd'),
+        ('avro', 'snappy'),
+    ])
+    def test_write_with_compression(self, file_format, compression):
+        pa_schema = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string()),
+            ('f2', pa.string())
+        ])
+        catalog = CatalogFactory.create({
+            "warehouse": self.warehouse
+        })
+        db_name = f"test_write_{file_format}_{compression}_db"
+        table_name = f"test_{file_format}_{compression}_table"
+        catalog.create_database(db_name, False)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'file.format': file_format,
+                'file.compression': compression
+            }
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': ['X', 'Y', 'Z']
+        }
+        expect = pa.Table.from_pydict(data, schema=pa_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        try:
+            table_write.write_arrow(expect)
+            commit_messages = table_write.prepare_commit()
+            table_commit.commit(commit_messages)
+            table_write.close()
+            table_commit.close()
+
+            self._verify_file_compression_with_format(
+                file_format, compression, db_name, table_name, expected_rows=3
+            )
+        except (ValueError, RuntimeError):
+            raise
+
+    def _verify_file_compression_with_format(
+            self, file_format: str, compression: str,
+            db_name: str, table_name: str, expected_rows: int = 3, expected_zstd_level: int = 1):
+        if file_format == 'parquet':
+            parquet_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.parquet")
+            self.assertEqual(len(parquet_files), 1)
+            import pyarrow.parquet as pq
+            parquet_file_path = parquet_files[0]
+            parquet_metadata = pq.read_metadata(parquet_file_path)
+            for i in range(parquet_metadata.num_columns):
+                column_metadata = parquet_metadata.row_group(0).column(i)
+                actual_compression = column_metadata.compression
+                compression_str = str(actual_compression).upper()
+                expected_compression_upper = compression.upper()
+                self.assertIn(
+                    expected_compression_upper, compression_str,
+                    f"Expected compression to be {compression}, but got {actual_compression}")
+                if compression.lower() == 'zstd' and hasattr(column_metadata, 'compression_level'):
+                    actual_level = column_metadata.compression_level
+                    self.assertEqual(
+                        actual_level, expected_zstd_level,
+                        f"Expected zstd compression level to be {expected_zstd_level}, but got {actual_level}")
+        elif file_format == 'orc':
+            orc_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.orc")
+            self.assertEqual(len(orc_files), 1)
+            import pyarrow.orc as orc
+            orc_file_path = orc_files[0]
+            orc_file = orc.ORCFile(orc_file_path)
+            try:
+                table = orc_file.read()
+                self.assertEqual(table.num_rows, expected_rows, "ORC file should contain expected rows")
+            except Exception as e:
+                self.fail(f"Failed to read ORC file (compression may be incorrect): {e}")
+        elif file_format == 'avro':
+            avro_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.avro")
+            self.assertEqual(len(avro_files), 1)
+            import fastavro
+            avro_file_path = avro_files[0]
+            with open(avro_file_path, 'rb') as f:
+                reader = fastavro.reader(f)
+                codec = reader.codec
+                expected_codec_map = {
+                    'zstd': 'zstandard',
+                    'zstandard': 'zstandard',
+                    'snappy': 'snappy',
+                    'deflate': 'deflate',
+                }
+                expected_codec = expected_codec_map.get(
+                    compression.lower(), compression.lower())
+                self.assertEqual(
+                    codec, expected_codec,
+                    f"Expected compression codec to be '{expected_codec}', but got '{codec}'")
+
+    def _verify_file_compression(self, file_format: str, db_name: str, table_name: str,
+                                 expected_rows: int = 3, expected_zstd_level: int = 1):
+        if file_format == 'parquet':
+            parquet_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.parquet")
+            self.assertEqual(len(parquet_files), 1)
+            import pyarrow.parquet as pq
+            parquet_file_path = parquet_files[0]
+            parquet_metadata = pq.read_metadata(parquet_file_path)
+            for i in range(parquet_metadata.num_columns):
+                column_metadata = parquet_metadata.row_group(0).column(i)
+                compression = column_metadata.compression
+                compression_str = str(compression).upper()
+                self.assertIn(
+                    'ZSTD', compression_str,
+                    f"Expected compression to be ZSTD , "
+                    f"but got {compression}")
+                if hasattr(column_metadata, 'compression_level'):
+                    actual_level = column_metadata.compression_level
+                    self.assertEqual(
+                        actual_level, expected_zstd_level,
+                        f"Expected zstd compression level to be {expected_zstd_level}, but got {actual_level}")
+        elif file_format == 'orc':
+            orc_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.orc")
+            self.assertEqual(len(orc_files), 1)
+            import pyarrow.orc as orc
+            orc_file_path = orc_files[0]
+            orc_file = orc.ORCFile(orc_file_path)
+            try:
+                table = orc_file.read()
+                self.assertEqual(table.num_rows, expected_rows, "ORC file should contain expected rows")
+            except Exception as e:
+                self.fail(f"Failed to read ORC file (compression may be incorrect): {e}")
+        elif file_format == 'avro':
+            avro_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.avro")
+            self.assertEqual(len(avro_files), 1)
+            import fastavro
+            avro_file_path = avro_files[0]
+            with open(avro_file_path, 'rb') as f:
+                reader = fastavro.reader(f)
+                codec = reader.codec
+                self.assertEqual(
+                    codec, 'zstandard',
+                    f"Expected compression codec to be 'zstandard', "
+                    f"but got '{codec}'")
+
     def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols, expected_fields_count, test_name):
         """Helper method to test a specific _VALUE_STATS_COLS case."""
 
diff --git a/paimon-python/pypaimon/write/writer/data_blob_writer.py b/paimon-python/pypaimon/write/writer/data_blob_writer.py
index 8cdd7428dc..800e21e5a6 100644
--- a/paimon-python/pypaimon/write/writer/data_blob_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_blob_writer.py
@@ -255,11 +255,11 @@ class DataBlobWriter(DataWriter):
 
         # Write file based on format
         if self.file_format == CoreOptions.FILE_FORMAT_PARQUET:
-            self.file_io.write_parquet(file_path, data, compression=self.compression)
+            self.file_io.write_parquet(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_ORC:
-            self.file_io.write_orc(file_path, data, compression=self.compression)
+            self.file_io.write_orc(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_AVRO:
-            self.file_io.write_avro(file_path, data)
+            self.file_io.write_avro(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_LANCE:
             self.file_io.write_lance(file_path, data)
         else:
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py
index fa5f004b8e..8e1230678f 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -56,6 +56,7 @@ class DataWriter(ABC):
         )
         self.file_format = self.options.file_format(default_format)
         self.compression = self.options.file_compression()
+        self.zstd_level = self.options.file_compression_zstd_level()
         self.sequence_generator = SequenceGenerator(max_seq_number)
 
         self.pending_data: Optional[pa.Table] = None
@@ -169,11 +170,11 @@ class DataWriter(ABC):
             external_path_str = None
 
         if self.file_format == CoreOptions.FILE_FORMAT_PARQUET:
-            self.file_io.write_parquet(file_path, data, compression=self.compression)
+            self.file_io.write_parquet(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_ORC:
-            self.file_io.write_orc(file_path, data, compression=self.compression)
+            self.file_io.write_orc(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_AVRO:
-            self.file_io.write_avro(file_path, data)
+            self.file_io.write_avro(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_BLOB:
             self.file_io.write_blob(file_path, data, self.blob_as_descriptor)
         elif self.file_format == CoreOptions.FILE_FORMAT_LANCE:
