This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 499f045  feat: add basic compression configuration to write_parquet (#459)
499f045 is described below

commit 499f0458dcceea93d1512258808a805986b677e7
Author: Daniel Mesejo <[email protected]>
AuthorDate: Tue Aug 22 15:44:42 2023 +0200

    feat: add basic compression configuration to write_parquet (#459)
---
 datafusion/tests/test_dataframe.py | 67 ++++++++++++++++++++++++++++++++++++++
 src/dataframe.rs                   | 58 +++++++++++++++++++++++++++++++--
 2 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index 78cb50f..ce7d89e 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -14,8 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 
 import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 
 from datafusion import functions as f
@@ -645,3 +647,68 @@ def test_describe(df):
         "b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0],
         "c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
     }
+
+
+def test_write_parquet(df, tmp_path):
+    path = tmp_path
+
+    df.write_parquet(str(path))
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "compression, compression_level",
+    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
+)
+def test_write_compressed_parquet(
+    df, tmp_path, compression, compression_level
+):
+    path = tmp_path
+
+    df.write_parquet(
+        str(path), compression=compression, compression_level=compression_level
+    )
+
+    # test that the actual compression scheme is the one written
+    for root, _dirs, files in os.walk(path):
+        for file in files:
+            if file.endswith(".parquet"):
+                metadata = pq.ParquetFile(os.path.join(root, file)).metadata
+                for row_group in metadata.to_dict()["row_groups"]:
+                    for columns in row_group["columns"]:
+                        assert columns["compression"].lower() == compression
+
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "compression, compression_level",
+    [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
+)
+def test_write_compressed_parquet_wrong_compression_level(
+    df, tmp_path, compression, compression_level
+):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(
+            str(path),
+            compression=compression,
+            compression_level=compression_level,
+        )
+
+
+@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
+def test_write_compressed_parquet_missing_compression_level(
+    df, tmp_path, compression
+):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(str(path), compression=compression)
diff --git a/src/dataframe.rs b/src/dataframe.rs
index b8d8ddc..61a4448 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -23,8 +23,10 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
+use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::prelude::*;
-use pyo3::exceptions::PyTypeError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
 use std::sync::Arc;
@@ -308,8 +310,58 @@ impl PyDataFrame {
     }
 
     /// Write a `DataFrame` to a Parquet file.
-    fn write_parquet(&self, path: &str, py: Python) -> PyResult<()> {
-        wait_for_future(py, self.df.as_ref().clone().write_parquet(path, None))?;
+    ///
+    /// `compression` is matched case-insensitively: snappy, gzip, brotli, zstd, lzo, lz4,
+    /// lz4_raw or uncompressed. gzip defaults to level 6, brotli and zstd require
+    /// `compression_level`; raises `ValueError` for an unknown codec or missing/invalid level.
+    #[pyo3(signature = (path, compression="uncompressed", compression_level=None))]
+    fn write_parquet(
+        &self,
+        path: &str,
+        compression: &str,
+        compression_level: Option<u32>,
+        py: Python,
+    ) -> PyResult<()> {
+        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
+            cl.ok_or_else(|| PyValueError::new_err("compression_level is not defined"))
+        }
+
+        let compression_type = match compression.to_lowercase().as_str() {
+            "snappy" => Compression::SNAPPY,
+            "gzip" => Compression::GZIP(
+                GzipLevel::try_new(compression_level.unwrap_or(6))
+                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
+            ),
+            "brotli" => Compression::BROTLI(
+                BrotliLevel::try_new(verify_compression_level(compression_level)?)
+                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
+            ),
+            "zstd" => Compression::ZSTD(
+                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
+                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
+            ),
+            "lzo" => Compression::LZO,
+            "lz4" => Compression::LZ4,
+            "lz4_raw" => Compression::LZ4_RAW,
+            "uncompressed" => Compression::UNCOMPRESSED,
+            _ => {
+                return Err(PyValueError::new_err(format!(
+                    "Unrecognized compression type {compression}"
+                )));
+            }
+        };
+
+        let writer_properties = WriterProperties::builder()
+            .set_compression(compression_type)
+            .build();
+
+        wait_for_future(
+            py,
+            self.df
+                .as_ref()
+                .clone()
+                .write_parquet(path, Some(writer_properties)),
+        )?;
         Ok(())
     }
 

Reply via email to