This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 499f045 feat: add basic compression configuration to write_parquet (#459)
499f045 is described below
commit 499f0458dcceea93d1512258808a805986b677e7
Author: Daniel Mesejo <[email protected]>
AuthorDate: Tue Aug 22 15:44:42 2023 +0200
feat: add basic compression configuration to write_parquet (#459)
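
In short, write_parquet now accepts a Parquet codec and an optional level. A
minimal usage sketch, assuming an existing datafusion DataFrame `df` (output
paths are illustrative):

    df.write_parquet("/tmp/out")  # defaults to compression="uncompressed"
    df.write_parquet("/tmp/out", compression="gzip")  # gzip falls back to level 6
    df.write_parquet("/tmp/out", compression="zstd", compression_level=15)

brotli and zstd have no default level, so compression_level is required for
them, as the tests below exercise.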
---
datafusion/tests/test_dataframe.py | 67 ++++++++++++++++++++++++++++++++++++++
src/dataframe.rs | 58 +++++++++++++++++++++++++++++++--
2 files changed, 122 insertions(+), 3 deletions(-)
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index 78cb50f..ce7d89e 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import os
import pyarrow as pa
+import pyarrow.parquet as pq
import pytest
from datafusion import functions as f
@@ -645,3 +647,68 @@ def test_describe(df):
"b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0],
"c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
}
+
+
+def test_write_parquet(df, tmp_path):
+ path = tmp_path
+
+ df.write_parquet(str(path))
+ result = pq.read_table(str(path)).to_pydict()
+ expected = df.to_pydict()
+
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "compression, compression_level",
+ [("gzip", 6), ("brotli", 7), ("zstd", 15)],
+)
+def test_write_compressed_parquet(
+ df, tmp_path, compression, compression_level
+):
+ path = tmp_path
+
+ df.write_parquet(
+ str(path), compression=compression, compression_level=compression_level
+ )
+
+ # test that the actual compression scheme is the one written
+ for root, dirs, files in os.walk(path):
+ for file in files:
+ if file.endswith(".parquet"):
+ metadata = pq.ParquetFile(os.path.join(root, file)).metadata.to_dict()
+ for row_group in metadata["row_groups"]:
+ for columns in row_group["columns"]:
+ assert columns["compression"].lower() == compression
+
+ result = pq.read_table(str(path)).to_pydict()
+ expected = df.to_pydict()
+
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "compression, compression_level",
+ [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
+)
+def test_write_compressed_parquet_wrong_compression_level(
+ df, tmp_path, compression, compression_level
+):
+ path = tmp_path
+
+ with pytest.raises(ValueError):
+ df.write_parquet(
+ str(path),
+ compression=compression,
+ compression_level=compression_level,
+ )
+
+
+@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
+def test_write_compressed_parquet_missing_compression_level(
+ df, tmp_path, compression
+):
+ path = tmp_path
+
+ with pytest.raises(ValueError):
+ df.write_parquet(str(path), compression=compression)
diff --git a/src/dataframe.rs b/src/dataframe.rs
index b8d8ddc..61a4448 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -23,8 +23,10 @@ use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::dataframe::DataFrame;
+use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::*;
-use pyo3::exceptions::PyTypeError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;
@@ -308,8 +310,58 @@ impl PyDataFrame {
}
/// Write a `DataFrame` to a Parquet file.
- fn write_parquet(&self, path: &str, py: Python) -> PyResult<()> {
- wait_for_future(py, self.df.as_ref().clone().write_parquet(path, None))?;
+ #[pyo3(signature = (
+ path,
+ compression="uncompressed",
+ compression_level=None
+ ))]
+ fn write_parquet(
+ &self,
+ path: &str,
+ compression: &str,
+ compression_level: Option<u32>,
+ py: Python,
+ ) -> PyResult<()> {
+ fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
+ cl.ok_or(PyValueError::new_err("compression_level is not defined"))
+ }
+
+ let compression_type = match compression.to_lowercase().as_str() {
+ "snappy" => Compression::SNAPPY,
+ "gzip" => Compression::GZIP(
+ GzipLevel::try_new(compression_level.unwrap_or(6))
+ .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+ ),
+ "brotli" => Compression::BROTLI(
+ BrotliLevel::try_new(verify_compression_level(compression_level)?)
+ .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+ ),
+ "zstd" => Compression::ZSTD(
+ ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
+ .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+ ),
+ "lz0" => Compression::LZO,
+ "lz4" => Compression::LZ4,
+ "lz4_raw" => Compression::LZ4_RAW,
+ "uncompressed" => Compression::UNCOMPRESSED,
+ _ => {
+ return Err(PyValueError::new_err(format!(
+ "Unrecognized compression type {compression}"
+ )));
+ }
+ };
+
+ let writer_properties = WriterProperties::builder()
+ .set_compression(compression_type)
+ .build();
+
+ wait_for_future(
+ py,
+ self.df
+ .as_ref()
+ .clone()
+ .write_parquet(path, Option::from(writer_properties)),
+ )?;
Ok(())
}
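
Taken together with the tests above, invalid settings surface in Python as
ValueError. A hedged sketch of the resulting behavior (level bounds come from
the parquet crate's validators; exact messages may differ):

    # zstd levels are checked by ZstdLevel::try_new (1..=22), so 23 fails
    try:
        df.write_parquet("/tmp/out", compression="zstd", compression_level=23)
    except ValueError as err:
        print(err)

    # brotli/zstd require an explicit level; omitting it raises
    # "compression_level is not defined"
    try:
        df.write_parquet("/tmp/out", compression="brotli")
    except ValueError as err:
        print(err)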