This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new c0be61b  LogicalPlan.to_variant() make public (#412)
c0be61b is described below

commit c0be61bd575ae717ec8c949693abd6cdec39a50c
Author: Jeremy Dyer <[email protected]>
AuthorDate: Mon Jul 3 11:59:59 2023 -0400

    LogicalPlan.to_variant() make public (#412)
    
    * Make to_variant public
    
    * Add to_variant() coverage for Subquery and SubqueryAlias
    
    * Add variant_name() function to Expr
    
    * Add function for friendly display of arrow DataType to Python layer since 
DataType enum values cannot be directly accessed
    
    * Change signature
    
    * Updated to include Decimal128 variant coverage
    
    * Add try_from to PyCreateView
    
    * Cargo lint fixes
---
 src/common/data_type.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++
 src/expr.rs             | 16 +++++++++++-----
 src/expr/create_view.rs | 15 +++++++++++++--
 src/expr/subquery.rs    | 44 ++++++++++++++++++++++++++++++++++++++++++++
 src/sql/logical.rs      |  6 +++++-
 5 files changed, 117 insertions(+), 8 deletions(-)

diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 622e1aa..85d2feb 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -503,6 +503,50 @@ impl DataTypeMap {
             )),
         }
     }
+
+    /// Returns the variant name of the Arrow `DataType` held by this map,
+    /// e.g. "Int64", "Utf8" or "Timestamp".
+    ///
+    /// PyO3 cannot expose arrow's `DataType` enum to Python directly, because we
+    /// cannot annotate an enum that is defined in a dependency crate. Returning
+    /// the variant name as a string gives the Python layer a way to inspect it.
+    /// Parameterized variants report only their name: for example
+    /// `Timestamp(_, _)` yields just "Timestamp" and `Decimal128(_, _)` yields
+    /// just "Decimal128"; the parameters are not included.
+    #[pyo3(name = "friendly_arrow_type_name")]
+    pub fn friendly_arrow_type_name(&self) -> PyResult<&str> {
+        Ok(match &self.arrow_type.data_type {
+            DataType::Null => "Null",
+            DataType::Boolean => "Boolean",
+            DataType::Int8 => "Int8",
+            DataType::Int16 => "Int16",
+            DataType::Int32 => "Int32",
+            DataType::Int64 => "Int64",
+            DataType::UInt8 => "UInt8",
+            DataType::UInt16 => "UInt16",
+            DataType::UInt32 => "UInt32",
+            DataType::UInt64 => "UInt64",
+            DataType::Float16 => "Float16",
+            DataType::Float32 => "Float32",
+            DataType::Float64 => "Float64",
+            DataType::Timestamp(_, _) => "Timestamp",
+            DataType::Date32 => "Date32",
+            DataType::Date64 => "Date64",
+            DataType::Time32(_) => "Time32",
+            DataType::Time64(_) => "Time64",
+            DataType::Duration(_) => "Duration",
+            DataType::Interval(_) => "Interval",
+            DataType::Binary => "Binary",
+            DataType::FixedSizeBinary(_) => "FixedSizeBinary",
+            DataType::LargeBinary => "LargeBinary",
+            DataType::Utf8 => "Utf8",
+            DataType::LargeUtf8 => "LargeUtf8",
+            DataType::List(_) => "List",
+            DataType::FixedSizeList(_, _) => "FixedSizeList",
+            DataType::LargeList(_) => "LargeList",
+            DataType::Struct(_) => "Struct",
+            DataType::Union(_, _) => "Union",
+            DataType::Dictionary(_, _) => "Dictionary",
+            DataType::Decimal128(_, _) => "Decimal128",
+            DataType::Decimal256(_, _) => "Decimal256",
+            DataType::Map(_, _) => "Map",
+            DataType::RunEndEncoded(_, _) => "RunEndEncoded",
+        })
+    }
 }
 
 /// PyO3 requires that objects passed between Rust and Python implement the 
trait `PyClass`
diff --git a/src/expr.rs b/src/expr.rs
index d519d0c..17b6c34 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -153,6 +153,12 @@ impl PyExpr {
         Ok(self.expr.canonical_name())
     }
 
+    /// Name of the underlying DataFusion `Expr` variant, for example
+    /// 'IsNotNull', 'Literal' or 'BinaryExpr'.
+    fn variant_name(&self) -> PyResult<&str> {
+        let name = self.expr.variant_name();
+        Ok(name)
+    }
+
     fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
         let expr = match op {
             CompareOp::Lt => self.expr.clone().lt(other.expr),
@@ -302,7 +308,7 @@ impl PyExpr {
                 ScalarValue::Boolean(v) => v.into_py(py),
                 ScalarValue::Float32(v) => v.into_py(py),
                 ScalarValue::Float64(v) => v.into_py(py),
-                ScalarValue::Decimal128(_, _, _) => todo!(),
+                ScalarValue::Decimal128(v, _, _) => v.into_py(py),
                 ScalarValue::Int8(v) => v.into_py(py),
                 ScalarValue::Int16(v) => v.into_py(py),
                 ScalarValue::Int32(v) => v.into_py(py),
@@ -323,10 +329,10 @@ impl PyExpr {
                 ScalarValue::Time32Millisecond(v) => v.into_py(py),
                 ScalarValue::Time64Microsecond(v) => v.into_py(py),
                 ScalarValue::Time64Nanosecond(v) => v.into_py(py),
-                ScalarValue::TimestampSecond(_, _) => todo!(),
-                ScalarValue::TimestampMillisecond(_, _) => todo!(),
-                ScalarValue::TimestampMicrosecond(_, _) => todo!(),
-                ScalarValue::TimestampNanosecond(_, _) => todo!(),
+                ScalarValue::TimestampSecond(v, _) => v.into_py(py),
+                ScalarValue::TimestampMillisecond(v, _) => v.into_py(py),
+                ScalarValue::TimestampMicrosecond(v, _) => v.into_py(py),
+                ScalarValue::TimestampNanosecond(v, _) => v.into_py(py),
                 ScalarValue::IntervalYearMonth(v) => v.into_py(py),
                 ScalarValue::IntervalDayTime(v) => v.into_py(py),
                 ScalarValue::IntervalMonthDayNano(v) => v.into_py(py),
diff --git a/src/expr/create_view.rs b/src/expr/create_view.rs
index 9d06239..febd723 100644
--- a/src/expr/create_view.rs
+++ b/src/expr/create_view.rs
@@ -17,10 +17,10 @@
 
 use std::fmt::{self, Display, Formatter};
 
-use datafusion_expr::CreateView;
+use datafusion_expr::{CreateView, DdlStatement, LogicalPlan};
 use pyo3::prelude::*;
 
-use crate::sql::logical::PyLogicalPlan;
+use crate::{errors::py_type_err, sql::logical::PyLogicalPlan};
 
 use super::logical_node::LogicalNode;
 
@@ -92,3 +92,14 @@ impl LogicalNode for PyCreateView {
         Ok(self.clone().into_py(py))
     }
 }
+
+impl TryFrom<LogicalPlan> for PyCreateView {
+    type Error = PyErr;
+
+    /// Unwraps a `LogicalPlan::Ddl(DdlStatement::CreateView(..))` into a
+    /// `PyCreateView`; any other plan shape is rejected with the error built by
+    /// `py_type_err`.
+    fn try_from(plan: LogicalPlan) -> Result<Self, Self::Error> {
+        if let LogicalPlan::Ddl(DdlStatement::CreateView(create)) = plan {
+            Ok(Self { create })
+        } else {
+            Err(py_type_err("unexpected plan"))
+        }
+    }
+}
diff --git a/src/expr/subquery.rs b/src/expr/subquery.rs
index 93ff244..f6f7b7f 100644
--- a/src/expr/subquery.rs
+++ b/src/expr/subquery.rs
@@ -15,9 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::fmt::{self, Display, Formatter};
+
 use datafusion_expr::Subquery;
 use pyo3::prelude::*;
 
+use crate::sql::logical::PyLogicalPlan;
+
+use super::logical_node::LogicalNode;
+
 #[pyclass(name = "Subquery", module = "datafusion.expr", subclass)]
 #[derive(Clone)]
 pub struct PySubquery {
@@ -35,3 +41,41 @@ impl From<Subquery> for PySubquery {
         PySubquery { subquery }
     }
 }
+
+impl Display for PySubquery {
+    /// Writes a header line followed by the `Debug` form of the wrapped
+    /// subquery plan and of its outer reference columns. The continuation
+    /// lines carry a literal 12-space indent inside the format string.
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        let plan = &self.subquery.subquery;
+        let outer_refs = &self.subquery.outer_ref_columns;
+        write!(
+            f,
+            "Subquery\n            Subquery: {:?}\n            outer_ref_columns: {:?}",
+            plan, outer_refs
+        )
+    }
+}
+
+#[pymethods]
+impl PySubquery {
+    /// Retrieves the input `LogicalPlan`(s) of this `Subquery` node, as
+    /// reported by its `LogicalNode::inputs` implementation.
+    fn input(&self) -> PyResult<Vec<PyLogicalPlan>> {
+        Ok(Self::inputs(self))
+    }
+
+    /// Python `repr()`: the `Display` rendering wrapped in `Subquery(...)`.
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("Subquery({})", self))
+    }
+
+    /// Name of this plan node type as exposed to Python.
+    fn __name__(&self) -> PyResult<String> {
+        Ok("Subquery".to_string())
+    }
+}
+
+impl LogicalNode for PySubquery {
+    /// Exposes the wrapped subquery plan (`subquery.subquery`) as this node's
+    /// single input, so that walking the plan tree from Python descends into
+    /// the subquery instead of stopping at this node.
+    fn inputs(&self) -> Vec<PyLogicalPlan> {
+        vec![PyLogicalPlan::from((*self.subquery.subquery).clone())]
+    }
+
+    /// Converts this wrapper into its Python object.
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        Ok(self.clone().into_py(py))
+    }
+}
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 07a3f65..2183155 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -28,6 +28,8 @@ use crate::expr::filter::PyFilter;
 use crate::expr::limit::PyLimit;
 use crate::expr::projection::PyProjection;
 use crate::expr::sort::PySort;
+use crate::expr::subquery::PySubquery;
+use crate::expr::subquery_alias::PySubqueryAlias;
 use crate::expr::table_scan::PyTableScan;
 use datafusion_expr::LogicalPlan;
 use pyo3::prelude::*;
@@ -56,7 +58,7 @@ impl PyLogicalPlan {
 #[pymethods]
 impl PyLogicalPlan {
     /// Return the specific logical operator
-    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+    pub fn to_variant(&self, py: Python) -> PyResult<PyObject> {
         Python::with_gil(|_| match self.plan.as_ref() {
             LogicalPlan::Aggregate(plan) => 
PyAggregate::from(plan.clone()).to_variant(py),
             LogicalPlan::Analyze(plan) => 
PyAnalyze::from(plan.clone()).to_variant(py),
@@ -69,6 +71,8 @@ impl PyLogicalPlan {
             LogicalPlan::Projection(plan) => 
PyProjection::from(plan.clone()).to_variant(py),
             LogicalPlan::Sort(plan) => 
PySort::from(plan.clone()).to_variant(py),
             LogicalPlan::TableScan(plan) => 
PyTableScan::from(plan.clone()).to_variant(py),
+            LogicalPlan::Subquery(plan) => 
PySubquery::from(plan.clone()).to_variant(py),
+            LogicalPlan::SubqueryAlias(plan) => 
PySubqueryAlias::from(plan.clone()).to_variant(py),
             other => Err(py_unsupported_variant_err(format!(
                 "Cannot convert this plan to a LogicalNode: {:?}",
                 other

Reply via email to