This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 617c3554b27 [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client
617c3554b27 is described below

commit 617c3554b2737a3cc3f9edc8e2685e94662c5251
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Mar 13 14:31:16 2023 +0800

    [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client
    
    ### What changes were proposed in this pull request?
    Add a helper function (`LiteralExpression._to_value`) to convert a proto literal back to a Python value in the Python client.
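
    A minimal round-trip sketch, mirroring the new unit test (assumes a PySpark dev build that includes this change):

        from pyspark.sql.connect.expressions import LiteralExpression

        # Wrap a Python value, serialize it to a proto literal, then convert it back.
        lit = LiteralExpression._from_value([1.0, 2.0, 3.0])
        proto_lit = lit.to_plan(None).literal
        assert LiteralExpression._to_value(proto_lit) == [1.0, 2.0, 3.0]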
    
    ### Why are the changes needed?
    Needed by the ML module (`.ml`) for Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is a dev-only change.
    
    ### How was this patch tested?
    Added a unit test.
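
    The optional `dataType` argument to `_to_value` acts as a sanity check; a mismatch between the expected type and the literal raises an AssertionError, e.g. (hedged sketch, same assumptions as above):

        from pyspark.sql.connect.expressions import LiteralExpression
        from pyspark.sql.types import StringType

        proto_lit = LiteralExpression._from_value(1.234567).to_plan(None).literal
        # The literal holds a double, so asking for a StringType value fails the assertion.
        LiteralExpression._to_value(proto_lit, StringType())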
    
    Closes #40376 from zhengruifeng/connect_literal_to_value.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/expressions.py          | 58 ++++++++++++++++++++++
 .../pyspark/sql/tests/connect/test_connect_plan.py | 50 +++++++++++++++++++
 2 files changed, 108 insertions(+)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index dbf260382f7..0e0aa49cda8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -71,6 +71,7 @@ from pyspark.sql.connect.types import (
     JVM_LONG_MAX,
     UnparsedDataType,
     pyspark_types_to_proto_types,
+    proto_schema_to_pyspark_data_type,
 )
 
 if TYPE_CHECKING:
@@ -308,6 +309,63 @@ class LiteralExpression(Expression):
     def _from_value(cls, value: Any) -> "LiteralExpression":
         return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
 
+    @classmethod
+    def _to_value(
+        cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
+    ) -> Any:
+        if literal.HasField("null"):
+            return None
+        elif literal.HasField("binary"):
+            assert dataType is None or isinstance(dataType, BinaryType)
+            return literal.binary
+        elif literal.HasField("boolean"):
+            assert dataType is None or isinstance(dataType, BooleanType)
+            return literal.boolean
+        elif literal.HasField("byte"):
+            assert dataType is None or isinstance(dataType, ByteType)
+            return literal.byte
+        elif literal.HasField("short"):
+            assert dataType is None or isinstance(dataType, ShortType)
+            return literal.short
+        elif literal.HasField("integer"):
+            assert dataType is None or isinstance(dataType, IntegerType)
+            return literal.integer
+        elif literal.HasField("long"):
+            assert dataType is None or isinstance(dataType, LongType)
+            return literal.long
+        elif literal.HasField("float"):
+            assert dataType is None or isinstance(dataType, FloatType)
+            return literal.float
+        elif literal.HasField("double"):
+            assert dataType is None or isinstance(dataType, DoubleType)
+            return literal.double
+        elif literal.HasField("decimal"):
+            assert dataType is None or isinstance(dataType, DecimalType)
+            return decimal.Decimal(literal.decimal.value)
+        elif literal.HasField("string"):
+            assert dataType is None or isinstance(dataType, StringType)
+            return literal.string
+        elif literal.HasField("date"):
+            assert dataType is None or isinstance(dataType, DateType)
+            return DateType().fromInternal(literal.date)
+        elif literal.HasField("timestamp"):
+            assert dataType is None or isinstance(dataType, TimestampType)
+            return TimestampType().fromInternal(literal.timestamp)
+        elif literal.HasField("timestamp_ntz"):
+            assert dataType is None or isinstance(dataType, TimestampNTZType)
+            return TimestampNTZType().fromInternal(literal.timestamp_ntz)
+        elif literal.HasField("day_time_interval"):
+            assert dataType is None or isinstance(dataType, DayTimeIntervalType)
+            return DayTimeIntervalType().fromInternal(literal.day_time_interval)
+        elif literal.HasField("array"):
+            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+            if dataType is not None:
+                assert isinstance(dataType, ArrayType)
+                assert elementType == dataType.elementType
+            return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+
+        raise TypeError(f"Unsupported Literal Value {literal}")
+
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         """Converts the literal expression to the literal in proto."""
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index f627136650d..129a25098b1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -42,6 +42,7 @@ if should_test_connect:
         IntegerType,
         MapType,
         ArrayType,
+        DoubleType,
     )
 
 
@@ -986,6 +987,55 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         self.assertEqual(len(l4.array.elements[1].array.elements), 2)
         self.assertEqual(len(l4.array.elements[2].array.elements), 0)
 
+    def test_literal_to_any_conversion(self):
+        for value in [
+            b"binary\0\0asas",
+            True,
+            False,
+            0,
+            12,
+            -1,
+            0.0,
+            1.234567,
+            decimal.Decimal(0.0),
+            decimal.Decimal(1.234567),
+            "sss",
+            datetime.date(2022, 12, 13),
+            datetime.datetime.now(),
+            datetime.timedelta(1, 2, 3),
+            [1, 2, 3, 4, 5, 6],
+            [-1.0, 2.0, 3.0],
+            ["x", "y", "z"],
+            [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]],
+        ]:
+            lit = LiteralExpression._from_value(value)
+            proto_lit = lit.to_plan(None).literal
+            value2 = LiteralExpression._to_value(proto_lit)
+            self.assertEqual(value, value2)
+
+        with self.assertRaises(AssertionError):
+            lit = LiteralExpression._from_value(1.234567)
+            proto_lit = lit.to_plan(None).literal
+            LiteralExpression._to_value(proto_lit, StringType())
+
+        with self.assertRaises(AssertionError):
+            lit = LiteralExpression._from_value("1.234567")
+            proto_lit = lit.to_plan(None).literal
+            LiteralExpression._to_value(proto_lit, DoubleType())
+
+        with self.assertRaises(AssertionError):
+            # build an array<string> proto literal, but with incorrect elements
+            proto_lit = proto.Expression().literal
+            proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType()))
+            proto_lit.array.elements.append(
+                LiteralExpression("string", StringType()).to_plan(None).literal
+            )
+            proto_lit.array.elements.append(
+                LiteralExpression(1.234, DoubleType()).to_plan(None).literal
+            )
+
+            LiteralExpression._to_value(proto_lit, DoubleType)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_plan import *  # noqa: F401

