This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 2d8b1d3 Default to ZSTD compression when writing Parquet (#981)
2d8b1d3 is described below
commit 2d8b1d32f4941b2e02a29e9135025a32ba6ae471
Author: kosiew <[email protected]>
AuthorDate: Sat Jan 11 10:12:04 2025 +0800
Default to ZSTD compression when writing Parquet (#981)
* fix: update default compression to ZSTD and improve documentation for
write_parquet method
* fix: clarify compression level documentation for ZSTD in write_parquet
method
* fix: update default compression level for ZSTD to 4 in write_parquet
method
* fix: improve docstring formatting for DataFrame parquet writing method
* feat: implement Compression enum and update write_parquet method to use it
* add test
* fix: remove unused import and update default compression to ZSTD in rs'
write_parquet method
* fix: update compression type strings to lowercase in DataFrame parquet
writing method doc
* test: update parquet compression tests to validate invalid and default
compression levels
* add comment on source of Compression
* docs: enhance Compression enum documentation and add default level method
* test: include gzip in default compression level tests for write_parquet
* refactor: simplify Compression enum methods and improve type handling in
DataFrame.write_parquet
* docs: update Compression enum methods to include return type descriptions
* move comment to within test
* Ruff format
---------
Co-authored-by: Tim Saucer <[email protected]>
---
python/datafusion/dataframe.py | 94 +++++++++++++++++++++++++++++++++++++++---
python/tests/test_dataframe.py | 14 ++++++-
src/dataframe.rs | 2 +-
3 files changed, 101 insertions(+), 9 deletions(-)
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 0b38db9..f8aef0c 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -21,7 +21,16 @@ See :ref:`user_guide_concepts` in the online documentation
for more information.
from __future__ import annotations
import warnings
-from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
+from typing import (
+ Any,
+ Iterable,
+ List,
+ TYPE_CHECKING,
+ Literal,
+ overload,
+ Optional,
+ Union,
+)
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -35,6 +44,60 @@ if TYPE_CHECKING:
from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr, SortExpr, sort_or_default
+from enum import Enum
+
+
+# excerpt from deltalake
+# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+class Compression(Enum):
+ """Enum representing the available compression types for Parquet files."""
+
+ UNCOMPRESSED = "uncompressed"
+ SNAPPY = "snappy"
+ GZIP = "gzip"
+ BROTLI = "brotli"
+ LZ4 = "lz4"
+ LZ0 = "lz0"
+ ZSTD = "zstd"
+ LZ4_RAW = "lz4_raw"
+
+ @classmethod
+ def from_str(cls, value: str) -> "Compression":
+ """Convert a string to a Compression enum value.
+
+ Args:
+ value: The string representation of the compression type.
+
+ Returns:
+ The Compression enum lowercase value.
+
+ Raises:
+ ValueError: If the string does not match any Compression enum value.
+ """
+ try:
+ return cls(value.lower())
+ except ValueError:
+ raise ValueError(
+ f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
+ )
+
+ def get_default_level(self) -> Optional[int]:
+ """Get the default compression level for the compression type.
+
+ Returns:
+ The default compression level for the compression type.
+ """
+ # GZIP, BROTLI default values from deltalake repo
+ # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+ # ZSTD default value from delta-rs
+ # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
+ if self == Compression.GZIP:
+ return 6
+ elif self == Compression.BROTLI:
+ return 1
+ elif self == Compression.ZSTD:
+ return 4
+ return None
class DataFrame:
@@ -620,17 +683,36 @@ class DataFrame:
def write_parquet(
self,
path: str | pathlib.Path,
- compression: str = "uncompressed",
+ compression: Union[str, Compression] = Compression.ZSTD,
compression_level: int | None = None,
) -> None:
"""Execute the :py:class:`DataFrame` and write the results to a
Parquet file.
Args:
path: Path of the Parquet file to write.
- compression: Compression type to use.
- compression_level: Compression level to use.
- """
- self.df.write_parquet(str(path), compression, compression_level)
+ compression: Compression type to use. Default is "ZSTD".
+ Available compression types are:
+ - "uncompressed": No compression.
+ - "snappy": Snappy compression.
+ - "gzip": Gzip compression.
+ - "brotli": Brotli compression.
+ - "lz0": LZ0 compression.
+ - "lz4": LZ4 compression.
+ - "lz4_raw": LZ4_RAW compression.
+ - "zstd": Zstandard compression.
+ compression_level: Compression level to use. For ZSTD, the
+ recommended range is 1 to 22, with the default being 4. Higher levels
+ provide better compression but slower speed.
+ """
+ # Convert string to Compression enum if necessary
+ if isinstance(compression, str):
+ compression = Compression.from_str(compression)
+
+ if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
+ if compression_level is None:
+ compression_level = compression.get_default_level()
+
+ self.df.write_parquet(str(path), compression.value, compression_level)
def write_json(self, path: str | pathlib.Path) -> None:
"""Execute the :py:class:`DataFrame` and write the results to a JSON
file.
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index e3bd1b2..fa5f4e8 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -1107,14 +1107,24 @@ def test_write_compressed_parquet_wrong_compression_level(
)
[email protected]("compression", ["brotli", "zstd", "wrong"])
-def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
[email protected]("compression", ["wrong"])
+def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
path = tmp_path
with pytest.raises(ValueError):
df.write_parquet(str(path), compression=compression)
[email protected]("compression", ["zstd", "brotli", "gzip"])
+def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
+ # Test write_parquet with zstd, brotli, gzip default compression level,
+ # ie don't specify compression level
+ # should complete without error
+ path = tmp_path
+
+ df.write_parquet(str(path), compression=compression)
+
+
def test_dataframe_export(df) -> None:
# Guarantees that we have the canonical implementation
# reading our dataframe export
diff --git a/src/dataframe.rs b/src/dataframe.rs
index fcb46a7..71a6fe6 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -463,7 +463,7 @@ impl PyDataFrame {
/// Write a `DataFrame` to a Parquet file.
#[pyo3(signature = (
path,
- compression="uncompressed",
+ compression="zstd",
compression_level=None
))]
fn write_parquet(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]