This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new e24dc75 feat: add register_avro and read_table (#461)
e24dc75 is described below
commit e24dc75f2fe60efb5bc888fd70d2aede80027c25
Author: Daniel Mesejo <[email protected]>
AuthorDate: Tue Aug 22 15:45:06 2023 +0200
feat: add register_avro and read_table (#461)
---
datafusion/tests/test_context.py | 10 ++++++++++
datafusion/tests/test_sql.py | 26 +++++++++++++++++++++++++
src/context.rs | 41 ++++++++++++++++++++++++++++++++++++++++
3 files changed, 77 insertions(+)
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 55a324a..97bff9b 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -214,6 +214,16 @@ def test_register_table(ctx, database):
assert public.names() == {"csv", "csv1", "csv2", "csv3"}
+def test_read_table(ctx, database):
+ default = ctx.catalog()
+ public = default.database("public")
+ assert public.names() == {"csv", "csv1", "csv2"}
+
+ table = public.table("csv")
+ table_df = ctx.read_table(table)
+ table_df.show()
+
+
def test_deregister_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 9d42a1f..19a2ad2 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
ctx.register_json("json4", gzip_path, file_compression_type="rar")
+def test_register_avro(ctx):
+ path = "testing/data/avro/alltypes_plain.avro"
+ ctx.register_avro("alltypes_plain", path)
+ result = ctx.sql(
+ "SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
+ ).collect()
+ result = pa.Table.from_batches(result).to_pydict()
+ assert result["tinyint_sum"][0] > 0
+
+ alternative_schema = pa.schema(
+ [
+ pa.field("id", pa.int64()),
+ ]
+ )
+
+ ctx.register_avro(
+ "alltypes_plain_schema",
+ path,
+ schema=alternative_schema,
+ infinite=False,
+ )
+ result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
+ result = pa.Table.from_batches(result)
+ assert result.schema == alternative_schema
+
+
def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]
diff --git a/src/context.rs b/src/context.rs
index 317ab78..c7f89f2 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -545,6 +545,39 @@ impl PySessionContext {
Ok(())
}
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (name,
+ path,
+ schema=None,
+ file_extension=".avro",
+ table_partition_cols=vec![],
+ infinite=false))]
+ fn register_avro(
+ &mut self,
+ name: &str,
+ path: PathBuf,
+ schema: Option<PyArrowType<Schema>>,
+ file_extension: &str,
+ table_partition_cols: Vec<(String, String)>,
+ infinite: bool,
+ py: Python,
+ ) -> PyResult<()> {
+ let path = path
+ .to_str()
+ .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+ let mut options = AvroReadOptions::default()
+ .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .mark_infinite(infinite);
+ options.file_extension = file_extension;
+ options.schema = schema.as_ref().map(|x| &x.0);
+
+ let result = self.ctx.register_avro(name, path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+ Ok(())
+ }
+
// Registers a PyArrow.Dataset
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) ->
PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
@@ -734,6 +767,14 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}
+ fn read_table(&self, table: &PyTable) -> PyResult<PyDataFrame> {
+ let df = self
+ .ctx
+ .read_table(table.table())
+ .map_err(DataFusionError::from)?;
+ Ok(PyDataFrame::new(df))
+ }
+
fn __repr__(&self) -> PyResult<String> {
let config = self.ctx.copied_config();
let mut config_entries = config