This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new e1b3740 feat: add compression options (#456)
e1b3740 is described below
commit e1b37401a2d1af86ab16b899f1dda8237a0d3535
Author: Daniel Mesejo <[email protected]>
AuthorDate: Fri Aug 11 21:46:22 2023 +0200
feat: add compression options (#456)
---
datafusion/tests/test_context.py | 38 +++++++++++++++++++++++++++++++++++++-
datafusion/tests/test_sql.py | 23 +++++++++++++++++++++--
src/context.rs | 31 +++++++++++++++++++++++++------
3 files changed, 83 insertions(+), 9 deletions(-)
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 6b1223a..55a324a 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import gzip
import os
import pyarrow as pa
@@ -336,11 +336,47 @@ def test_read_json(ctx):
assert result[0].column(1) == pa.array([1, 2, 3])
+def test_read_json_compressed(ctx, tmp_path):
+ path = os.path.dirname(os.path.abspath(__file__))
+ test_data_path = os.path.join(path, "data_test_context", "data.json")
+
+ # File compression type
+ gzip_path = tmp_path / "data.json.gz"
+
+ with open(test_data_path, "rb") as csv_file:
+ with gzip.open(gzip_path, "wb") as gzipped_file:
+ gzipped_file.writelines(csv_file)
+
+ df = ctx.read_json(
+ gzip_path, file_extension=".gz", file_compression_type="gz"
+ )
+ result = df.collect()
+
+ assert result[0].column(0) == pa.array(["a", "b", "c"])
+ assert result[0].column(1) == pa.array([1, 2, 3])
+
+
def test_read_csv(ctx):
csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
csv_df.select(column("c1")).show()
+def test_read_csv_compressed(ctx, tmp_path):
+ test_data_path = "testing/data/csv/aggregate_test_100.csv"
+
+ # File compression type
+ gzip_path = tmp_path / "aggregate_test_100.csv.gz"
+
+ with open(test_data_path, "rb") as csv_file:
+ with gzip.open(gzip_path, "wb") as gzipped_file:
+ gzipped_file.writelines(csv_file)
+
+ csv_df = ctx.read_csv(
+ gzip_path, file_extension=".gz", file_compression_type="gz"
+ )
+ csv_df.select(column("c1")).show()
+
+
def test_read_parquet(ctx):
csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
csv_df.show()
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 638a222..608bb19 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -19,6 +19,7 @@ import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
+import gzip
from datafusion import udf
@@ -32,6 +33,7 @@ def test_no_table(ctx):
def test_register_csv(ctx, tmp_path):
path = tmp_path / "test.csv"
+ gzip_path = tmp_path / "test.csv.gz"
table = pa.Table.from_arrays(
[
@@ -43,6 +45,10 @@ def test_register_csv(ctx, tmp_path):
)
pa.csv.write_csv(table, path)
+ with open(path, "rb") as csv_file:
+ with gzip.open(gzip_path, "wb") as gzipped_file:
+ gzipped_file.writelines(csv_file)
+
ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
@@ -52,6 +58,13 @@ def test_register_csv(ctx, tmp_path):
delimiter=",",
schema_infer_max_records=10,
)
+ ctx.register_csv(
+ "csv_gzip",
+ gzip_path,
+ file_extension="gz",
+ file_compression_type="gzip",
+ )
+
alternative_schema = pa.schema(
[
("some_int", pa.int16()),
@@ -61,9 +74,9 @@ def test_register_csv(ctx, tmp_path):
)
ctx.register_csv("csv3", path, schema=alternative_schema)
- assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}
+ assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"}
- for table in ["csv", "csv1", "csv2"]:
+ for table in ["csv", "csv1", "csv2", "csv_gzip"]:
result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"cnt": [4]}
@@ -77,6 +90,12 @@ def test_register_csv(ctx, tmp_path):
):
ctx.register_csv("csv4", path, delimiter="wrong")
+ with pytest.raises(
+ ValueError,
+ match="file_compression_type must one of: gzip, bz2, xz, zstd",
+ ):
+ ctx.register_csv("csv4", path, file_compression_type="rar")
+
def test_register_parquet(ctx, tmp_path):
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
diff --git a/src/context.rs b/src/context.rs
index cf133d7..1dca8a7 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -17,6 +17,7 @@
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
+use std::str::FromStr;
use std::sync::Arc;
use object_store::ObjectStore;
@@ -40,6 +41,7 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::file_format::file_type::FileCompressionType;
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{SessionConfig, SessionContext,
TaskContext};
@@ -469,7 +471,8 @@ impl PySessionContext {
has_header=true,
delimiter=",",
schema_infer_max_records=1000,
- file_extension=".csv"))]
+ file_extension=".csv",
+ file_compression_type=None))]
fn register_csv(
&mut self,
name: &str,
@@ -479,6 +482,7 @@ impl PySessionContext {
delimiter: &str,
schema_infer_max_records: usize,
file_extension: &str,
+ file_compression_type: Option<String>,
py: Python,
) -> PyResult<()> {
let path = path
@@ -495,7 +499,8 @@ impl PySessionContext {
.has_header(has_header)
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
- .file_extension(file_extension);
+ .file_extension(file_extension)
+
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_csv(name, path, options);
@@ -559,7 +564,7 @@ impl PySessionContext {
}
#[allow(clippy::too_many_arguments)]
- #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000,
file_extension=".json", table_partition_cols=vec![]))]
+ #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000,
file_extension=".json", table_partition_cols=vec![],
file_compression_type=None))]
fn read_json(
&mut self,
path: PathBuf,
@@ -567,13 +572,15 @@ impl PySessionContext {
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
+ file_compression_type: Option<String>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a
string"))?;
let mut options = NdJsonReadOptions::default()
-
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+
.file_compression_type(parse_file_compression_type(file_compression_type)?);
options.schema_infer_max_records = schema_infer_max_records;
options.file_extension = file_extension;
let df = if let Some(schema) = schema {
@@ -595,7 +602,8 @@ impl PySessionContext {
delimiter=",",
schema_infer_max_records=1000,
file_extension=".csv",
- table_partition_cols=vec![]))]
+ table_partition_cols=vec![],
+ file_compression_type=None))]
fn read_csv(
&self,
path: PathBuf,
@@ -605,6 +613,7 @@ impl PySessionContext {
schema_infer_max_records: usize,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
+ file_compression_type: Option<String>,
py: Python,
) -> PyResult<PyDataFrame> {
let path = path
@@ -623,7 +632,8 @@ impl PySessionContext {
.delimiter(delimiter[0])
.schema_infer_max_records(schema_infer_max_records)
.file_extension(file_extension)
-
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+
.file_compression_type(parse_file_compression_type(file_compression_type)?);
if let Some(py_schema) = schema {
options.schema = Some(&py_schema.0);
@@ -743,6 +753,15 @@ fn convert_table_partition_cols(
.collect::<Result<Vec<_>, _>>()
}
+/// Converts the optional `file_compression_type` argument received from Python
+/// into a DataFusion [`FileCompressionType`].
+///
+/// `None` is passed to `FileCompressionType::from_str` as an empty string
+/// (presumably parsed as "uncompressed" — confirm against the DataFusion version in use).
+///
+/// # Errors
+///
+/// Returns a Python `ValueError` when `FileCompressionType::from_str` rejects the name.
+fn parse_file_compression_type(
+    file_compression_type: Option<String>,
+) -> Result<FileCompressionType, PyErr> {
+    // `as_deref` borrows the inner `String` as `&str`, so no temporary `String`
+    // is allocated just to supply the empty default.
+    FileCompressionType::from_str(file_compression_type.as_deref().unwrap_or(""))
+        .map_err(|_| {
+            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
+        })
+}
+
+
impl From<PySessionContext> for SessionContext {
fn from(ctx: PySessionContext) -> SessionContext {
ctx.ctx