This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e24dc75  feat: add register_avro and read_table (#461)
e24dc75 is described below

commit e24dc75f2fe60efb5bc888fd70d2aede80027c25
Author: Daniel Mesejo <[email protected]>
AuthorDate: Tue Aug 22 15:45:06 2023 +0200

    feat: add register_avro and read_table (#461)
---
 datafusion/tests/test_context.py | 10 ++++++++++
 datafusion/tests/test_sql.py     | 26 +++++++++++++++++++++++++
 src/context.rs                   | 41 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 55a324a..97bff9b 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -214,6 +214,16 @@ def test_register_table(ctx, database):
     assert public.names() == {"csv", "csv1", "csv2", "csv3"}
 
 
+def test_read_table(ctx, database):
+    default = ctx.catalog()
+    public = default.database("public")
+    assert public.names() == {"csv", "csv1", "csv2"}
+
+    table = public.table("csv")
+    table_df = ctx.read_table(table)
+    table_df.show()
+
+
 def test_deregister_table(ctx, database):
     default = ctx.catalog()
     public = default.database("public")
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 9d42a1f..19a2ad2 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
         ctx.register_json("json4", gzip_path, file_compression_type="rar")
 
 
+def test_register_avro(ctx):
+    path = "testing/data/avro/alltypes_plain.avro"
+    ctx.register_avro("alltypes_plain", path)
+    result = ctx.sql(
+        "SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
+    ).collect()
+    result = pa.Table.from_batches(result).to_pydict()
+    assert result["tinyint_sum"][0] > 0
+
+    alternative_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+        ]
+    )
+
+    ctx.register_avro(
+        "alltypes_plain_schema",
+        path,
+        schema=alternative_schema,
+        infinite=False,
+    )
+    result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
 
diff --git a/src/context.rs b/src/context.rs
index 317ab78..c7f89f2 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -545,6 +545,39 @@ impl PySessionContext {
         Ok(())
     }
 
+    /// Registers the Avro data at `path` as a table named `name`.
+    ///
+    /// `schema`, when given, is used instead of reading one from the file;
+    /// `table_partition_cols` and `infinite` are forwarded to `AvroReadOptions`.
+    #[allow(clippy::too_many_arguments)] // one Rust argument per Python keyword
+    #[pyo3(signature = (name, path, schema=None, file_extension=".avro",
+                        table_partition_cols=vec![], infinite=false))]
+    fn register_avro(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        infinite: bool,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .mark_infinite(infinite);
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_avro(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> 
PyResult<()> {
         let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, 
py)?);
@@ -734,6 +767,14 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
+    /// Returns a DataFrame that reads from the provider backing `table`.
+    fn read_table(&self, table: &PyTable) -> PyResult<PyDataFrame> {
+        let provider = table.table();
+        let df = self.ctx.read_table(provider);
+        let df = df.map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(df))
+    }
+
     fn __repr__(&self) -> PyResult<String> {
         let config = self.ctx.copied_config();
         let mut config_entries = config

Reply via email to