This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e1b3740  feat: add compression options (#456)
e1b3740 is described below

commit e1b37401a2d1af86ab16b899f1dda8237a0d3535
Author: Daniel Mesejo <[email protected]>
AuthorDate: Fri Aug 11 21:46:22 2023 +0200

    feat: add compression options (#456)
---
 datafusion/tests/test_context.py | 38 +++++++++++++++++++++++++++++++++++++-
 datafusion/tests/test_sql.py     | 23 +++++++++++++++++++++--
 src/context.rs                   | 31 +++++++++++++++++++++++++------
 3 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 6b1223a..55a324a 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import gzip
 import os
 
 import pyarrow as pa
@@ -336,11 +336,47 @@ def test_read_json(ctx):
     assert result[0].column(1) == pa.array([1, 2, 3])
 
 
+def test_read_json_compressed(ctx, tmp_path):
+    path = os.path.dirname(os.path.abspath(__file__))
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+
+    # File compression type
+    gzip_path = tmp_path / "data.json.gz"
+
+    with open(test_data_path, "rb") as csv_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(csv_file)
+
+    df = ctx.read_json(
+        gzip_path, file_extension=".gz", file_compression_type="gz"
+    )
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array(["a", "b", "c"])
+    assert result[0].column(1) == pa.array([1, 2, 3])
+
+
 def test_read_csv(ctx):
     csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
     csv_df.select(column("c1")).show()
 
 
+def test_read_csv_compressed(ctx, tmp_path):
+    test_data_path = "testing/data/csv/aggregate_test_100.csv"
+
+    # File compression type
+    gzip_path = tmp_path / "aggregate_test_100.csv.gz"
+
+    with open(test_data_path, "rb") as csv_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(csv_file)
+
+    csv_df = ctx.read_csv(
+        gzip_path, file_extension=".gz", file_compression_type="gz"
+    )
+    csv_df.select(column("c1")).show()
+
+
 def test_read_parquet(ctx):
     csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
     csv_df.show()
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 638a222..608bb19 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -19,6 +19,7 @@ import numpy as np
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
+import gzip
 
 from datafusion import udf
 
@@ -32,6 +33,7 @@ def test_no_table(ctx):
 
 def test_register_csv(ctx, tmp_path):
     path = tmp_path / "test.csv"
+    gzip_path = tmp_path / "test.csv.gz"
 
     table = pa.Table.from_arrays(
         [
@@ -43,6 +45,10 @@ def test_register_csv(ctx, tmp_path):
     )
     pa.csv.write_csv(table, path)
 
+    with open(path, "rb") as csv_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(csv_file)
+
     ctx.register_csv("csv", path)
     ctx.register_csv("csv1", str(path))
     ctx.register_csv(
@@ -52,6 +58,13 @@ def test_register_csv(ctx, tmp_path):
         delimiter=",",
         schema_infer_max_records=10,
     )
+    ctx.register_csv(
+        "csv_gzip",
+        gzip_path,
+        file_extension="gz",
+        file_compression_type="gzip",
+    )
+
     alternative_schema = pa.schema(
         [
             ("some_int", pa.int16()),
@@ -61,9 +74,9 @@ def test_register_csv(ctx, tmp_path):
     )
     ctx.register_csv("csv3", path, schema=alternative_schema)
 
-    assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}
+    assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"}
 
-    for table in ["csv", "csv1", "csv2"]:
+    for table in ["csv", "csv1", "csv2", "csv_gzip"]:
         result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect()
         result = pa.Table.from_batches(result)
         assert result.to_pydict() == {"cnt": [4]}
@@ -77,6 +90,12 @@ def test_register_csv(ctx, tmp_path):
     ):
         ctx.register_csv("csv4", path, delimiter="wrong")
 
+    with pytest.raises(
+        ValueError,
+        match="file_compression_type must one of: gzip, bz2, xz, zstd",
+    ):
+        ctx.register_csv("csv4", path, file_compression_type="rar")
+
 
 def test_register_parquet(ctx, tmp_path):
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
diff --git a/src/context.rs b/src/context.rs
index cf133d7..1dca8a7 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -17,6 +17,7 @@
 
 use std::collections::{HashMap, HashSet};
 use std::path::PathBuf;
+use std::str::FromStr;
 use std::sync::Arc;
 
 use object_store::ObjectStore;
@@ -40,6 +41,7 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::file_format::file_type::FileCompressionType;
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
@@ -469,7 +471,8 @@ impl PySessionContext {
                         has_header=true,
                         delimiter=",",
                         schema_infer_max_records=1000,
-                        file_extension=".csv"))]
+                        file_extension=".csv",
+                        file_compression_type=None))]
     fn register_csv(
         &mut self,
         name: &str,
@@ -479,6 +482,7 @@ impl PySessionContext {
         delimiter: &str,
         schema_infer_max_records: usize,
         file_extension: &str,
+        file_compression_type: Option<String>,
         py: Python,
     ) -> PyResult<()> {
         let path = path
@@ -495,7 +499,8 @@ impl PySessionContext {
             .has_header(has_header)
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
-            .file_extension(file_extension);
+            .file_extension(file_extension)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_csv(name, path, options);
@@ -559,7 +564,7 @@ impl PySessionContext {
     }
 
     #[allow(clippy::too_many_arguments)]
-    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))]
+    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
     fn read_json(
         &mut self,
         path: PathBuf,
@@ -567,13 +572,15 @@ impl PySessionContext {
         schema_infer_max_records: usize,
         file_extension: &str,
         table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let path = path
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
         let mut options = NdJsonReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema_infer_max_records = schema_infer_max_records;
         options.file_extension = file_extension;
         let df = if let Some(schema) = schema {
@@ -595,7 +602,8 @@ impl PySessionContext {
         delimiter=",",
         schema_infer_max_records=1000,
         file_extension=".csv",
-        table_partition_cols=vec![]))]
+        table_partition_cols=vec![],
+        file_compression_type=None))]
     fn read_csv(
         &self,
         path: PathBuf,
@@ -605,6 +613,7 @@ impl PySessionContext {
         schema_infer_max_records: usize,
         file_extension: &str,
         table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let path = path
@@ -623,7 +632,8 @@ impl PySessionContext {
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
 
         if let Some(py_schema) = schema {
             options.schema = Some(&py_schema.0);
@@ -743,6 +753,15 @@ fn convert_table_partition_cols(
         .collect::<Result<Vec<_>, _>>()
 }
 
+/// Parses the optional compression name (`None` becomes ""); failures map to `ValueError`.
+fn parse_file_compression_type(
+    file_compression_type: Option<String>,
+) -> Result<FileCompressionType, PyErr> {
+    let name = file_compression_type.as_deref().unwrap_or("");
+    let message = "file_compression_type must one of: gzip, bz2, xz, zstd";
+    FileCompressionType::from_str(name).map_err(|_| PyValueError::new_err(message))
+}
+
 impl From<PySessionContext> for SessionContext {
     fn from(ctx: PySessionContext) -> SessionContext {
         ctx.ctx

Reply via email to