This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 617c3554b27 [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client 617c3554b27 is described below commit 617c3554b2737a3cc3f9edc8e2685e94662c5251 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Mon Mar 13 14:31:16 2023 +0800 [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client ### What changes were proposed in this pull request? Add a helper function that converts a proto literal to a Python value in the Python client. ### Why are the changes needed? It is needed in `.ml`. ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? Added unit tests. Closes #40376 from zhengruifeng/connect_literal_to_value. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/expressions.py | 58 ++++++++++++++++++++++ .../pyspark/sql/tests/connect/test_connect_plan.py | 50 +++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index dbf260382f7..0e0aa49cda8 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -71,6 +71,7 @@ from pyspark.sql.connect.types import ( JVM_LONG_MAX, UnparsedDataType, pyspark_types_to_proto_types, + proto_schema_to_pyspark_data_type, ) if TYPE_CHECKING: @@ -308,6 +309,63 @@ class LiteralExpression(Expression): def _from_value(cls, value: Any) -> "LiteralExpression": return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value)) + @classmethod + def _to_value( + cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None + ) -> Any: + if literal.HasField("null"): + return None + elif literal.HasField("binary"): + assert dataType is None or isinstance(dataType, BinaryType) + return literal.binary + elif literal.HasField("boolean"):
+ assert dataType is None or isinstance(dataType, BooleanType) + return literal.boolean + elif literal.HasField("byte"): + assert dataType is None or isinstance(dataType, ByteType) + return literal.byte + elif literal.HasField("short"): + assert dataType is None or isinstance(dataType, ShortType) + return literal.short + elif literal.HasField("integer"): + assert dataType is None or isinstance(dataType, IntegerType) + return literal.integer + elif literal.HasField("long"): + assert dataType is None or isinstance(dataType, LongType) + return literal.long + elif literal.HasField("float"): + assert dataType is None or isinstance(dataType, FloatType) + return literal.float + elif literal.HasField("double"): + assert dataType is None or isinstance(dataType, DoubleType) + return literal.double + elif literal.HasField("decimal"): + assert dataType is None or isinstance(dataType, DecimalType) + return decimal.Decimal(literal.decimal.value) + elif literal.HasField("string"): + assert dataType is None or isinstance(dataType, StringType) + return literal.string + elif literal.HasField("date"): + assert dataType is None or isinstance(dataType, DataType) + return DateType().fromInternal(literal.date) + elif literal.HasField("timestamp"): + assert dataType is None or isinstance(dataType, TimestampType) + return TimestampType().fromInternal(literal.timestamp) + elif literal.HasField("timestamp_ntz"): + assert dataType is None or isinstance(dataType, TimestampNTZType) + return TimestampNTZType().fromInternal(literal.timestamp_ntz) + elif literal.HasField("day_time_interval"): + assert dataType is None or isinstance(dataType, DayTimeIntervalType) + return DayTimeIntervalType().fromInternal(literal.day_time_interval) + elif literal.HasField("array"): + elementType = proto_schema_to_pyspark_data_type(literal.array.element_type) + if dataType is not None: + assert isinstance(dataType, ArrayType) + assert elementType == dataType.elementType + return [LiteralExpression._to_value(v, 
elementType) for v in literal.array.elements] + + raise TypeError(f"Unsupported Literal Value {literal}") + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": """Converts the literal expression to the literal in proto.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index f627136650d..129a25098b1 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -42,6 +42,7 @@ if should_test_connect: IntegerType, MapType, ArrayType, + DoubleType, ) @@ -986,6 +987,55 @@ class SparkConnectPlanTests(PlanOnlyTestFixture): self.assertEqual(len(l4.array.elements[1].array.elements), 2) self.assertEqual(len(l4.array.elements[2].array.elements), 0) + def test_literal_to_any_conversion(self): + for value in [ + b"binary\0\0asas", + True, + False, + 0, + 12, + -1, + 0.0, + 1.234567, + decimal.Decimal(0.0), + decimal.Decimal(1.234567), + "sss", + datetime.date(2022, 12, 13), + datetime.datetime.now(), + datetime.timedelta(1, 2, 3), + [1, 2, 3, 4, 5, 6], + [-1.0, 2.0, 3.0], + ["x", "y", "z"], + [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]], + ]: + lit = LiteralExpression._from_value(value) + proto_lit = lit.to_plan(None).literal + value2 = LiteralExpression._to_value(proto_lit) + self.assertEqual(value, value2) + + with self.assertRaises(AssertionError): + lit = LiteralExpression._from_value(1.234567) + proto_lit = lit.to_plan(None).literal + LiteralExpression._to_value(proto_lit, StringType()) + + with self.assertRaises(AssertionError): + lit = LiteralExpression._from_value("1.234567") + proto_lit = lit.to_plan(None).literal + LiteralExpression._to_value(proto_lit, DoubleType()) + + with self.assertRaises(AssertionError): + # build a array<string> proto literal, but with incorrect elements + proto_lit = proto.Expression().literal + proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType())) + 
proto_lit.array.elements.append( + LiteralExpression("string", StringType()).to_plan(None).literal + ) + proto_lit.array.elements.append( + LiteralExpression(1.234, DoubleType()).to_plan(None).literal + ) + + LiteralExpression._to_value(proto_lit, DoubleType) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org