This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e7ee7601d38 [SPARK-47903][PYTHON] Add support for remaining scalar types in the PySpark Variant library
9e7ee7601d38 is described below

commit 9e7ee7601d38bb76715df16c3bb8655c5667aac3
Author: Harsh Motwani <harsh.motw...@databricks.com>
AuthorDate: Tue Apr 23 08:36:34 2024 +0900

    [SPARK-47903][PYTHON] Add support for remaining scalar types in the PySpark Variant library
    
    ### What changes were proposed in this pull request?
    
    Added support for the `date`, `timestamp`, `timestamp_ntz`, `float` and `binary` scalar types to the variant library in Python. Data of these types can now also be extracted from a variant.
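
    For illustration, a minimal sketch of how the new scalar encodings documented in `variant_utils.py` decode (the helper names and payload framing are hypothetical; the offsets assume the 1-byte type header has already been skipped):

    ```python
    import datetime
    import struct

    EPOCH = datetime.date(1970, 1, 1)

    def decode_float(payload: bytes) -> float:
        # FLOAT: 4-byte little-endian IEEE float
        return struct.unpack("<f", payload[:4])[0]

    def decode_date(payload: bytes) -> datetime.date:
        # DATE: 4-byte little-endian signed day count from the Unix epoch
        days = int.from_bytes(payload[:4], "little", signed=True)
        return datetime.date.fromordinal(EPOCH.toordinal() + days)

    decode_float(struct.pack("<f", 5.5))        # 5.5
    decode_date((18628).to_bytes(4, "little"))  # datetime.date(2021, 1, 1)
    ```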
    
    ### Why are the changes needed?
    
    Support for these types was added to the Scala side as part of a recent PR. This PR adds the same support for these data types on the PySpark side.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can now use PySpark to extract data of more types from Variants.
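
    For example (a sketch; it assumes an active SparkSession named `spark`, and the literals are illustrative):

    ```python
    row = spark.sql(
        "SELECT CAST(DATE'2021-01-01' AS variant) AS d, "
        "CAST(FLOAT(5.5) AS variant) AS f, "
        "CAST(TIMESTAMP'1940-01-01 12:35:13.123' AS variant) AS t"
    ).collect()[0]

    row["d"].toPython()                     # datetime.date(2021, 1, 1)
    row["f"].toPython()                     # 5.5
    row["t"].toJson("America/Los_Angeles")  # JSON string rendered in the given time zone
    ```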
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46122 from harshmotw-db/python_scalar_variant.
    
    Authored-by: Harsh Motwani <harsh.motw...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../source/reference/pyspark.sql/variant_val.rst   |   1 +
 python/pyspark/sql/tests/test_types.py             | 115 +++++++++++++++++++-
 python/pyspark/sql/types.py                        |  13 +++
 python/pyspark/sql/variant_utils.py                | 117 ++++++++++++++++++---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../catalyst/expressions/ExpectsInputTypes.scala   |   1 -
 .../expressions/variant/variantExpressions.scala   |  30 ++++++
 .../sql-functions/sql-expression-schema.md         |   1 +
 .../apache/spark/sql/VariantEndToEndSuite.scala    |  40 +++++++
 9 files changed, 301 insertions(+), 18 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/variant_val.rst b/python/docs/source/reference/pyspark.sql/variant_val.rst
index a7f592c18e3a..8630ae8aace1 100644
--- a/python/docs/source/reference/pyspark.sql/variant_val.rst
+++ b/python/docs/source/reference/pyspark.sql/variant_val.rst
@@ -25,3 +25,4 @@ VariantVal
     :toctree: api/
 
     VariantVal.toPython
+    VariantVal.toJson
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index af13adbc21bb..7d45adb832c8 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1427,8 +1427,10 @@ class TypesTestsMixin:
             ("-int4", "-69633", -69633),
             ("int8", "4295033089", 4295033089),
             ("-int8", "-4294967297", -4294967297),
-            ("float4", "1.23456789e-30", 1.23456789e-30),
-            ("-float4", "-4.56789e+29", -4.56789e29),
+            ("float4", "3.402e+38", 3.402e38),
+            ("-float4", "-3.402e+38", -3.402e38),
+            ("float8", "1.79769e+308", 1.79769e308),
+            ("-float8", "-1.79769e+308", -1.79769e308),
             ("dec4", "123.456", Decimal("123.456")),
             ("-dec4", "-321.654", Decimal("-321.654")),
             ("dec8", "429.4967297", Decimal("429.4967297")),
@@ -1447,17 +1449,77 @@ class TypesTestsMixin:
             F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"),
             F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"),
         ).collect()[0]
-        variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]]
+
+        # These data types are not supported by parse_json yet so they are being handled
+        # separately - Date, Timestamp, TimestampNTZ, Binary, Float (Single Precision)
+        date_columns = self.spark.sql(
+            "select cast(Date('2021-01-01')"
+            + " as variant) as d0, cast(Date('1800-12-31')"
+            + " as variant) as d1"
+        ).collect()[0]
+        float_columns = self.spark.sql(
+            "select cast(Float(5.5)" + " as variant) as f0, cast(Float(-5.5) 
as variant) as f1"
+        ).collect()[0]
+        binary_columns = self.spark.sql(
+            "select cast(binary(x'324FA69E')" + " as variant) as b"
+        ).collect()[0]
+        timestamp_ntz_columns = self.spark.sql(
+            "select cast(cast('1940-01-01 12:33:01.123'"
+            + " as timestamp_ntz) as variant) as tntz0, cast(cast('2522-12-31 05:57:13'"
+            + " as timestamp_ntz) as variant) as tntz1, cast(cast('0001-07-15 17:43:26+08:00'"
+            + " as timestamp_ntz) as variant) as tntz2"
+        ).collect()[0]
+        timestamp_columns = self.spark.sql(
+            "select cast(cast('1940-01-01 12:35:13.123+7:30'"
+            + " as timestamp) as variant) as t0, cast(cast('2522-12-31 00:00:00-5:23'"
+            + " as timestamp) as variant) as t1, cast(cast('0001-12-31 01:01:01+08:00'"
+            + " as timestamp) as variant) as t2"
+        ).collect()[0]
+
+        variants = [
+            row["v"],
+            row["a"][0],
+            row["s"]["col1"],
+            row["m"]["k"],
+            date_columns["d0"],
+            date_columns["d1"],
+            float_columns["f0"],
+            float_columns["f1"],
+            binary_columns["b"],
+            timestamp_ntz_columns["tntz0"],
+            timestamp_ntz_columns["tntz1"],
+            timestamp_ntz_columns["tntz2"],
+            timestamp_columns["t0"],
+            timestamp_columns["t1"],
+            timestamp_columns["t2"],
+        ]
+
         for v in variants:
             self.assertEqual(type(v), VariantVal)
 
-        # check str
+        # check str (to_json)
         as_string = str(variants[0])
         for key, expected, _ in expected_values:
             self.assertTrue('"%s":%s' % (key, expected) in as_string)
         self.assertEqual(str(variants[1]), '{"a":1}')
         self.assertEqual(str(variants[2]), '{"b":"2"}')
         self.assertEqual(str(variants[3]), '{"c":true}')
+        self.assertEqual(str(variants[4]), '"2021-01-01"')
+        self.assertEqual(str(variants[5]), '"1800-12-31"')
+        self.assertEqual(str(variants[6]), "5.5")
+        self.assertEqual(str(variants[7]), "-5.5")
+        self.assertEqual(str(variants[8]), '"Mk+mng=="')
+        self.assertEqual(str(variants[9]), '"1940-01-01 12:33:01.123000"')
+        self.assertEqual(str(variants[10]), '"2522-12-31 05:57:13"')
+        self.assertEqual(str(variants[11]), '"0001-07-15 17:43:26"')
+        self.assertEqual(str(variants[12]), '"1940-01-01 05:05:13.123000+00:00"')
+        self.assertEqual(str(variants[13]), '"2522-12-31 05:23:00+00:00"')
+        self.assertEqual(str(variants[14]), '"0001-12-30 17:01:01+00:00"')
+
+        # Check to_json on timestamps with custom timezones
+        self.assertEqual(
+            variants[12].toJson("America/Los_Angeles"), '"1939-12-31 21:05:13.123000-08:00"'
+        )
 
         # check toPython
         as_python = variants[0].toPython()
@@ -1466,6 +1528,51 @@ class TypesTestsMixin:
         self.assertEqual(variants[1].toPython(), {"a": 1})
         self.assertEqual(variants[2].toPython(), {"b": "2"})
         self.assertEqual(variants[3].toPython(), {"c": True})
+        self.assertEqual(variants[4].toPython(), datetime.date(2021, 1, 1))
+        self.assertEqual(variants[5].toPython(), datetime.date(1800, 12, 31))
+        self.assertEqual(variants[6].toPython(), float(5.5))
+        self.assertEqual(variants[7].toPython(), float(-5.5))
+        self.assertEqual(variants[8].toPython(), bytearray(b"2O\xa6\x9e"))
+        self.assertEqual(variants[9].toPython(), datetime.datetime(1940, 1, 1, 12, 33, 1, 123000))
+        self.assertEqual(variants[10].toPython(), datetime.datetime(2522, 12, 31, 5, 57, 13))
+        self.assertEqual(variants[11].toPython(), datetime.datetime(1, 7, 15, 17, 43, 26))
+        self.assertEqual(
+            variants[12].toPython(),
+            datetime.datetime(
+                1940,
+                1,
+                1,
+                12,
+                35,
+                13,
+                123000,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=7, minutes=30)),
+            ),
+        )
+        self.assertEqual(
+            variants[13].toPython(),
+            datetime.datetime(
+                2522,
+                12,
+                31,
+                3,
+                3,
+                31,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=-2, minutes=-20, seconds=31)),
+            ),
+        )
+        self.assertEqual(
+            variants[14].toPython(),
+            datetime.datetime(
+                1,
+                12,
+                31,
+                16,
+                3,
+                23,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=23, minutes=2, seconds=22)),
+            ),
+        )
 
         # check repr
         self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 3546fd822814..48aa3e8e4fab 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1521,6 +1521,19 @@ class VariantVal:
         """
         return VariantUtils.to_python(self.value, self.metadata)
 
+    def toJson(self, zone_id: str = "UTC") -> str:
+        """
+        Convert the VariantVal to a JSON string. The zone ID represents the time zone that
+        timestamps should be printed in. It defaults to UTC. The list of valid zone IDs can be
+        found by importing the `zoneinfo` module and running :code:`zoneinfo.available_timezones()`.
+
+        Returns
+        -------
+        str
+            A JSON string that represents the Variant.
+        """
+        return VariantUtils.to_json(self.value, self.metadata, zone_id)
+
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 11dc29503921..1ee139506b91 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -15,12 +15,15 @@
 # limitations under the License.
 #
 
+import base64
 import decimal
+import datetime
 import json
 import struct
 from array import array
 from typing import Any, Callable, Dict, List, Tuple
 from pyspark.errors import PySparkValueError
+from zoneinfo import ZoneInfo
 
 
 class VariantUtils:
@@ -86,19 +89,41 @@ class VariantUtils:
     DECIMAL8 = 9
     # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
     DECIMAL16 = 10
+    # Date value. Content is 4-byte little-endian signed integer that represents the number of days
+    # from the Unix epoch.
+    DATE = 11
+    # Timestamp value. Content is 8-byte little-endian signed integer that represents the number of
+    # microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. This is a timezone-aware
+    # field and when reading into a Python datetime object defaults to the UTC timezone.
+    TIMESTAMP = 12
+    # Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted
+    # as if the local time zone is UTC.
+    TIMESTAMP_NTZ = 13
+    # 4-byte IEEE float.
+    FLOAT = 14
+    # Binary value. The content is (4-byte little-endian unsigned integer representing the binary
+    # size) + (size bytes of binary content).
+    BINARY = 15
     # Long string value. The content is (4-byte little-endian unsigned integer representing the
     # string size) + (size bytes of string content).
     LONG_STR = 16
 
     U32_SIZE = 4
 
+    EPOCH = datetime.datetime(
+        year=1970, month=1, day=1, hour=0, minute=0, second=0, tzinfo=datetime.timezone.utc
+    )
+    EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)
+
     @classmethod
-    def to_json(cls, value: bytes, metadata: bytes) -> str:
+    def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
         """
-        Convert the VariantVal to a JSON string.
+        Convert the VariantVal to a JSON string. The `zone_id` parameter denotes the time zone that
+        timestamp fields should be printed in. It defaults to "UTC". The list of valid zone IDs can
+        be found by importing the `zoneinfo` module and running `zoneinfo.available_timezones()`.
         :return: JSON string
         """
-        return cls._to_json(value, metadata, 0)
+        return cls._to_json(value, metadata, 0, zone_id)
 
     @classmethod
     def to_python(cls, value: bytes, metadata: bytes) -> str:
@@ -168,12 +193,41 @@ class VariantUtils:
             return cls._read_long(value, pos + 1, 1, signed=True)
         elif type_info == VariantUtils.INT2:
             return cls._read_long(value, pos + 1, 2, signed=True)
-        elif type_info == VariantUtils.INT4:
+        elif type_info == VariantUtils.INT4 or type_info == VariantUtils.DATE:
             return cls._read_long(value, pos + 1, 4, signed=True)
         elif type_info == VariantUtils.INT8:
             return cls._read_long(value, pos + 1, 8, signed=True)
         raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
+    @classmethod
+    def _get_date(cls, value: bytes, pos: int) -> datetime.date:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        if type_info == VariantUtils.DATE:
+            days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
+            return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        if type_info == VariantUtils.TIMESTAMP_NTZ:
+            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
+            return VariantUtils.EPOCH_NTZ + datetime.timedelta(
+                microseconds=microseconds_since_epoch
+            )
+        if type_info == VariantUtils.TIMESTAMP:
+            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
+            return (
+                VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
+            ).astimezone(ZoneInfo(zone_id))
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
         cls._check_index(pos, len(value))
@@ -197,9 +251,15 @@ class VariantUtils:
     def _get_double(cls, value: bytes, pos: int) -> float:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
-        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.DOUBLE:
+        if basic_type != VariantUtils.PRIMITIVE:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
-        return struct.unpack("d", value[pos + 1 : pos + 9])[0]
+        if type_info == VariantUtils.FLOAT:
+            cls._check_index(pos + 4, len(value))
+            return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
+        elif type_info == VariantUtils.DOUBLE:
+            cls._check_index(pos + 8, len(value))
+            return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
     @classmethod
     def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
@@ -220,6 +280,17 @@ class VariantUtils:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
         return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
 
+    @classmethod
+    def _get_binary(cls, value: bytes, pos: int) -> bytes:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        start = pos + 1 + VariantUtils.U32_SIZE
+        length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
+        cls._check_index(start + length - 1, len(value))
+        return bytes(value[start : start + length])
+
     @classmethod
     def _get_type(cls, value: bytes, pos: int) -> Any:
         """
@@ -244,7 +315,7 @@ class VariantUtils:
             or type_info == VariantUtils.INT8
         ):
             return int
-        elif type_info == VariantUtils.DOUBLE:
+        elif type_info == VariantUtils.DOUBLE or type_info == VariantUtils.FLOAT:
             return float
         elif (
             type_info == VariantUtils.DECIMAL4
@@ -252,18 +323,24 @@ class VariantUtils:
             or type_info == VariantUtils.DECIMAL16
         ):
             return decimal.Decimal
+        elif type_info == VariantUtils.BINARY:
+            return bytes
+        elif type_info == VariantUtils.DATE:
+            return datetime.date
+        elif type_info == VariantUtils.TIMESTAMP or type_info == VariantUtils.TIMESTAMP_NTZ:
+            return datetime.datetime
         elif type_info == VariantUtils.LONG_STR:
             return str
         raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
     @classmethod
-    def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+    def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
         variant_type = cls._get_type(value, pos)
         if variant_type == dict:
 
             def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> str:
                 key_value_list = [
-                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos)
+                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos, zone_id)
                     for (key, value_pos) in key_value_pos_list
                 ]
                 return "{" + ",".join(key_value_list) + "}"
@@ -273,19 +350,25 @@ class VariantUtils:
 
             def handle_array(value_pos_list: List[int]) -> str:
                 value_list = [
-                    cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list
+                    cls._to_json(value, metadata, value_pos, zone_id)
+                    for value_pos in value_pos_list
                 ]
                 return "[" + ",".join(value_list) + "]"
 
             return cls._handle_array(value, pos, handle_array)
         else:
-            value = cls._get_scalar(variant_type, value, metadata, pos)
+            value = cls._get_scalar(variant_type, value, metadata, pos, zone_id)
             if value is None:
                 return "null"
             if type(value) == bool:
                 return "true" if value else "false"
             if type(value) == str:
                 return json.dumps(value)
+            if type(value) == bytes:
+                # decoding simply converts byte array to string
+                return '"' + base64.b64encode(value).decode("utf-8") + '"'
+            if type(value) == datetime.date or type(value) == datetime.datetime:
+                return '"' + str(value) + '"'
             return str(value)
 
     @classmethod
@@ -311,10 +394,12 @@ class VariantUtils:
 
             return cls._handle_array(value, pos, handle_array)
         else:
-            return cls._get_scalar(variant_type, value, metadata, pos)
+            return cls._get_scalar(variant_type, value, metadata, pos, zone_id="UTC")
 
     @classmethod
-    def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any:
+    def _get_scalar(
+        cls, variant_type: Any, value: bytes, metadata: bytes, pos: int, zone_id: str
+    ) -> Any:
         if isinstance(None, variant_type):
             return None
         elif variant_type == bool:
@@ -327,6 +412,12 @@ class VariantUtils:
             return cls._get_double(value, pos)
         elif variant_type == decimal.Decimal:
             return cls._get_decimal(value, pos)
+        elif variant_type == bytes:
+            return cls._get_binary(value, pos)
+        elif variant_type == datetime.date:
+            return cls._get_date(value, pos)
+        elif variant_type == datetime.datetime:
+            return cls._get_timestamp(value, pos, zone_id)
         else:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e4e663d15167..5f43cc106e67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -822,6 +822,7 @@ object FunctionRegistry {
 
     // Variant
     expression[ParseJson]("parse_json"),
+    expression[TryParseJson]("try_parse_json"),
     expression[IsVariantNull]("is_variant_null"),
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 1a4a0271c54b..66c2f736f235 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -48,7 +48,6 @@ trait ExpectsInputTypes extends Expression {
 }
 
 object ExpectsInputTypes extends QueryErrorsBase {
-
   def checkInputDataTypes(
       inputs: Seq[Expression],
       inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 6c4a8f90e3b5..07f08aa7e70e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -75,6 +75,36 @@ case class ParseJson(child: Expression)
     copy(child = newChild)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Returns 
null when the string is not valid JSON value.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('{"a":1,"b":0.8}');
+       {"a":1,"b":0.8}
+  """,
+  since = "4.0.0",
+  group = "variant_funcs"
+)
+// scalastyle:on line.size.limit
+case class TryParseJson(expr: Expression, replacement: Expression)
+  extends RuntimeReplaceable with InheritAnalysisRules {
+  def this(child: Expression) = this(child, TryEval(ParseJson(child)))
+
+  override def parameters: Seq[Expression] = Seq(expr)
+
+  override def dataType: DataType = VariantType
+
+  override def prettyName: String = "try_parse_json"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(replacement = newChild)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    ExpectsInputTypes.checkInputDataTypes(Seq(expr), Seq(StringType))
+  }
+}
+
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Check if a variant value is a variant null. Returns 
true if and only if the input is a variant null and false otherwise (including 
in the case of SQL NULL).",
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index ae9e68c4cbb1..a2fa30b7f364 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -440,6 +440,7 @@
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct<schema_of_variant_agg(parse_json(j)):string> |
+| org.apache.spark.sql.catalyst.expressions.variant.TryParseJson | try_parse_json | SELECT try_parse_json('{"a":1,"b":0.8}') | struct<try_parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index d53b49f7ab5a..96e85dc58b40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -88,6 +88,46 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
   }
 
+  test("try_parse_json/to_json round-trip") {
+    def check(input: String, output: String = "INPUT IS OUTPUT"): Unit = {
+      val df = Seq(input).toDF("v")
+      val variantDF = df.selectExpr("to_json(try_parse_json(v)) as v").select(Column("v"))
+      val expected = if (output != "INPUT IS OUTPUT") output else input
+      checkAnswer(variantDF, Seq(Row(expected)))
+    }
+
+    check("null")
+    check("true")
+    check("false")
+    check("-1")
+    check("1.0E10")
+    check("\"\"")
+    check("\"" + ("a" * 63) + "\"")
+    check("\"" + ("b" * 64) + "\"")
+    // scalastyle:off nonascii
+    check("\"" + ("你好,世界" * 20) + "\"")
+    // scalastyle:on nonascii
+    check("[]")
+    check("{}")
+    // scalastyle:off nonascii
+    check(
+      "[null, true,   false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+      "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+    )
+    // scalastyle:on nonascii
+    check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+    // Places where parse_json should fail and therefore, try_parse_json should return null
+    check("{1:2}", null)
+    check("{\"a\":1", null)
+    check("{\"a\":[a,b,c]}", null)
+  }
+
+  test("try_parse_json with invalid input type") {
+    // This test is required because the type checking logic in try_parse_json is custom.
+    val exception = intercept[Exception](spark.sql("select try_parse_json(1)"))
+    assert(exception != null)
+  }
+
   test("to_json with nested variant") {
     val df = Seq(1).toDF("v")
     val variantDF1 = df.select(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
