This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9e7ee7601d38 [SPARK-47903][PYTHON] Add support for remaining scalar types in the PySpark Variant library
9e7ee7601d38 is described below

commit 9e7ee7601d38bb76715df16c3bb8655c5667aac3
Author: Harsh Motwani <harsh.motw...@databricks.com>
AuthorDate: Tue Apr 23 08:36:34 2024 +0900

    [SPARK-47903][PYTHON] Add support for remaining scalar types in the PySpark Variant library

    ### What changes were proposed in this pull request?

    Added support for the `date`, `timestamp`, `timestamp_ntz`, `float` and `binary` scalar types to the Variant library in Python. Data of these types can now also be extracted from a Variant.

    ### Why are the changes needed?

    Support for these types was added to the Scala side in a recent PR. This PR adds support for the same data types on the PySpark side.

    ### Does this PR introduce _any_ user-facing change?

    Yes, users can now use PySpark to extract data of more types from Variants.

    ### How was this patch tested?

    Unit tests

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #46122 from harshmotw-db/python_scalar_variant.

    Authored-by: Harsh Motwani <harsh.motw...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../source/reference/pyspark.sql/variant_val.rst   |   1 +
 python/pyspark/sql/tests/test_types.py             | 115 +++++++++++++++++++-
 python/pyspark/sql/types.py                        |  13 +++
 python/pyspark/sql/variant_utils.py                | 117 ++++++++++++++++++---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../catalyst/expressions/ExpectsInputTypes.scala   |   1 -
 .../expressions/variant/variantExpressions.scala   |  30 ++++++
 .../sql-functions/sql-expression-schema.md         |   1 +
 .../apache/spark/sql/VariantEndToEndSuite.scala    |  40 +++++++
 9 files changed, 301 insertions(+), 18 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/variant_val.rst b/python/docs/source/reference/pyspark.sql/variant_val.rst
index a7f592c18e3a..8630ae8aace1 100644
--- a/python/docs/source/reference/pyspark.sql/variant_val.rst
+++ b/python/docs/source/reference/pyspark.sql/variant_val.rst
@@ -25,3 +25,4 @@ VariantVal
     :toctree: api/
 
     VariantVal.toPython
+    VariantVal.toJson
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index af13adbc21bb..7d45adb832c8 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1427,8 +1427,10 @@ class TypesTestsMixin:
             ("-int4", "-69633", -69633),
             ("int8", "4295033089", 4295033089),
             ("-int8", "-4294967297", -4294967297),
-            ("float4", "1.23456789e-30", 1.23456789e-30),
-            ("-float4", "-4.56789e+29", -4.56789e29),
+            ("float4", "3.402e+38", 3.402e38),
+            ("-float4", "-3.402e+38", -3.402e38),
+            ("float8", "1.79769e+308", 1.79769e308),
+            ("-float8", "-1.79769e+308", -1.79769e308),
             ("dec4", "123.456", Decimal("123.456")),
             ("-dec4", "-321.654", Decimal("-321.654")),
             ("dec8", "429.4967297", Decimal("429.4967297")),
@@ -1447,17 +1449,77 @@ class TypesTestsMixin:
             F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"),
             F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"),
         ).collect()[0]
-        variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]]
+
+        # These data types are not supported by parse_json yet, so they are handled
+        # separately - Date, Timestamp, TimestampNTZ, Binary, Float (single precision)
+        date_columns = self.spark.sql(
+            "select cast(Date('2021-01-01')"
+            + " as variant) as d0, cast(Date('1800-12-31')"
+            + " as variant) as d1"
+        ).collect()[0]
+        float_columns = self.spark.sql(
+            "select cast(Float(5.5)" + " as variant) as f0, cast(Float(-5.5) as variant) as f1"
+        ).collect()[0]
+        binary_columns = self.spark.sql(
+            "select cast(binary(x'324FA69E')" + " as variant) as b"
+        ).collect()[0]
+        timestamp_ntz_columns = self.spark.sql(
+            "select cast(cast('1940-01-01 12:33:01.123'"
+            + " as timestamp_ntz) as variant) as tntz0, cast(cast('2522-12-31 05:57:13'"
+            + " as timestamp_ntz) as variant) as tntz1, cast(cast('0001-07-15 17:43:26+08:00'"
+            + " as timestamp_ntz) as variant) as tntz2"
+        ).collect()[0]
+        timestamp_columns = self.spark.sql(
+            "select cast(cast('1940-01-01 12:35:13.123+7:30'"
+            + " as timestamp) as variant) as t0, cast(cast('2522-12-31 00:00:00-5:23'"
+            + " as timestamp) as variant) as t1, cast(cast('0001-12-31 01:01:01+08:00'"
+            + " as timestamp) as variant) as t2"
+        ).collect()[0]
+
+        variants = [
+            row["v"],
+            row["a"][0],
+            row["s"]["col1"],
+            row["m"]["k"],
+            date_columns["d0"],
+            date_columns["d1"],
+            float_columns["f0"],
+            float_columns["f1"],
+            binary_columns["b"],
+            timestamp_ntz_columns["tntz0"],
+            timestamp_ntz_columns["tntz1"],
+            timestamp_ntz_columns["tntz2"],
+            timestamp_columns["t0"],
+            timestamp_columns["t1"],
+            timestamp_columns["t2"],
+        ]
+
         for v in variants:
             self.assertEqual(type(v), VariantVal)
 
-        # check str
+        # check str (to_json)
         as_string = str(variants[0])
         for key, expected, _ in expected_values:
             self.assertTrue('"%s":%s' % (key, expected) in as_string)
         self.assertEqual(str(variants[1]), '{"a":1}')
         self.assertEqual(str(variants[2]), '{"b":"2"}')
         self.assertEqual(str(variants[3]), '{"c":true}')
+        self.assertEqual(str(variants[4]), '"2021-01-01"')
+        self.assertEqual(str(variants[5]), '"1800-12-31"')
+        self.assertEqual(str(variants[6]), "5.5")
+        self.assertEqual(str(variants[7]), "-5.5")
+        self.assertEqual(str(variants[8]), '"Mk+mng=="')
+        self.assertEqual(str(variants[9]), '"1940-01-01 12:33:01.123000"')
+        self.assertEqual(str(variants[10]), '"2522-12-31 05:57:13"')
+        self.assertEqual(str(variants[11]), '"0001-07-15 17:43:26"')
+        self.assertEqual(str(variants[12]), '"1940-01-01 05:05:13.123000+00:00"')
+        self.assertEqual(str(variants[13]), '"2522-12-31 05:23:00+00:00"')
+        self.assertEqual(str(variants[14]), '"0001-12-30 17:01:01+00:00"')
+
+        # Check to_json on timestamps with custom timezones
+        self.assertEqual(
+            variants[12].toJson("America/Los_Angeles"), '"1939-12-31 21:05:13.123000-08:00"'
+        )
 
         # check toPython
         as_python = variants[0].toPython()
@@ -1466,6 +1528,51 @@ class TypesTestsMixin:
         self.assertEqual(variants[1].toPython(), {"a": 1})
         self.assertEqual(variants[2].toPython(), {"b": "2"})
         self.assertEqual(variants[3].toPython(), {"c": True})
+        self.assertEqual(variants[4].toPython(), datetime.date(2021, 1, 1))
+        self.assertEqual(variants[5].toPython(), datetime.date(1800, 12, 31))
+        self.assertEqual(variants[6].toPython(), float(5.5))
+        self.assertEqual(variants[7].toPython(), float(-5.5))
+        self.assertEqual(variants[8].toPython(), bytearray(b"2O\xa6\x9e"))
+        self.assertEqual(variants[9].toPython(), datetime.datetime(1940, 1, 1, 12, 33, 1, 123000))
+        self.assertEqual(variants[10].toPython(), datetime.datetime(2522, 12, 31, 5, 57, 13))
+        self.assertEqual(variants[11].toPython(), datetime.datetime(1, 7, 15, 17, 43, 26))
+        self.assertEqual(
+            variants[12].toPython(),
+            datetime.datetime(
+                1940,
+                1,
+                1,
+                12,
+                35,
+                13,
+                123000,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=7, minutes=30)),
+            ),
+        )
+        self.assertEqual(
+            variants[13].toPython(),
+            datetime.datetime(
+                2522,
+                12,
+                31,
+                3,
+                3,
+                31,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=-2, minutes=-20, seconds=31)),
+            ),
+        )
+        self.assertEqual(
+            variants[14].toPython(),
+            datetime.datetime(
+                1,
+                12,
+                31,
+                16,
+                3,
+                23,
+                tzinfo=datetime.timezone(datetime.timedelta(hours=23, minutes=2, seconds=22)),
+            ),
+        )
 
         # check repr
         self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 3546fd822814..48aa3e8e4fab 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1521,6 +1521,19 @@ class VariantVal:
         """
         return VariantUtils.to_python(self.value, self.metadata)
 
+    def toJson(self, zone_id: str = "UTC") -> str:
+        """
+        Convert the VariantVal to a JSON string. The zone ID represents the time zone in which
+        timestamps should be printed; it defaults to UTC. The list of valid zone IDs can be
+        found by importing the `zoneinfo` module and running
+        :code:`zoneinfo.available_timezones()`.
+
+        Returns
+        -------
+        str
+            A JSON string that represents the Variant.
+        """
+        return VariantUtils.to_json(self.value, self.metadata, zone_id)
+
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 11dc29503921..1ee139506b91 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -15,12 +15,15 @@
 # limitations under the License.
 #
 
+import base64
 import decimal
+import datetime
 import json
 import struct
 from array import array
 from typing import Any, Callable, Dict, List, Tuple
 from pyspark.errors import PySparkValueError
+from zoneinfo import ZoneInfo
 
 
 class VariantUtils:
@@ -86,19 +89,41 @@ class VariantUtils:
     DECIMAL8 = 9
     # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
     DECIMAL16 = 10
+    # Date value. Content is a 4-byte little-endian signed integer that represents the number
+    # of days from the Unix epoch.
+    DATE = 11
+    # Timestamp value. Content is an 8-byte little-endian signed integer that represents the
+    # number of microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. This is a
+    # timezone-aware field, and reading it into a Python datetime object defaults to the UTC
+    # timezone.
+    TIMESTAMP = 12
+    # Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be
+    # interpreted as if the local time zone is UTC.
+    TIMESTAMP_NTZ = 13
+    # 4-byte IEEE float.
+    FLOAT = 14
+    # Binary value. The content is (4-byte little-endian unsigned integer representing the
+    # binary size) + (size bytes of binary content).
+    BINARY = 15
     # Long string value. The content is (4-byte little-endian unsigned integer representing the
     # string size) + (size bytes of string content).
     LONG_STR = 16
 
     U32_SIZE = 4
 
+    EPOCH = datetime.datetime(
+        year=1970, month=1, day=1, hour=0, minute=0, second=0, tzinfo=datetime.timezone.utc
+    )
+    EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)
+
     @classmethod
-    def to_json(cls, value: bytes, metadata: bytes) -> str:
+    def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
         """
-        Convert the VariantVal to a JSON string.
+        Convert the VariantVal to a JSON string. The `zone_id` parameter denotes the time zone
+        in which timestamp fields should be printed; it defaults to "UTC". The list of valid
+        zone IDs can be found by importing the `zoneinfo` module and running
+        `zoneinfo.available_timezones()`.
 
         :return: JSON string
         """
-        return cls._to_json(value, metadata, 0)
+        return cls._to_json(value, metadata, 0, zone_id)
 
     @classmethod
     def to_python(cls, value: bytes, metadata: bytes) -> str:
@@ -168,12 +193,41 @@ class VariantUtils:
             return cls._read_long(value, pos + 1, 1, signed=True)
         elif type_info == VariantUtils.INT2:
             return cls._read_long(value, pos + 1, 2, signed=True)
-        elif type_info == VariantUtils.INT4:
+        elif type_info == VariantUtils.INT4 or type_info == VariantUtils.DATE:
             return cls._read_long(value, pos + 1, 4, signed=True)
         elif type_info == VariantUtils.INT8:
             return cls._read_long(value, pos + 1, 8, signed=True)
         raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
+    @classmethod
+    def _get_date(cls, value: bytes, pos: int) -> datetime.date:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        if type_info == VariantUtils.DATE:
+            days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
+            return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        if type_info == VariantUtils.TIMESTAMP_NTZ:
+            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
+            return VariantUtils.EPOCH_NTZ + datetime.timedelta(
+                microseconds=microseconds_since_epoch
+            )
+        if type_info == VariantUtils.TIMESTAMP:
+            microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
+            return (
+                VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
+            ).astimezone(ZoneInfo(zone_id))
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
         cls._check_index(pos, len(value))
@@ -197,9 +251,15 @@ class VariantUtils:
     def _get_double(cls, value: bytes, pos: int) -> float:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
-        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.DOUBLE:
+        if basic_type != VariantUtils.PRIMITIVE:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
-        return struct.unpack("d", value[pos + 1 : pos + 9])[0]
+        if type_info == VariantUtils.FLOAT:
+            cls._check_index(pos + 4, len(value))
+            return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
+        elif type_info == VariantUtils.DOUBLE:
+            cls._check_index(pos + 8, len(value))
+            return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
     @classmethod
     def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
@@ -220,6 +280,17 @@ class VariantUtils:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
         return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
 
+    @classmethod
+    def _get_binary(cls, value: bytes, pos: int) -> bytes:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        start = pos + 1 + VariantUtils.U32_SIZE
+        length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
+        cls._check_index(start + length - 1, len(value))
+        return bytes(value[start : start + length])
+
     @classmethod
     def _get_type(cls, value: bytes, pos: int) -> Any:
         """
@@ -244,7 +315,7 @@ class VariantUtils:
             or type_info == VariantUtils.INT8
         ):
             return int
-        elif type_info == VariantUtils.DOUBLE:
+        elif type_info == VariantUtils.DOUBLE or type_info == VariantUtils.FLOAT:
             return float
         elif (
             type_info == VariantUtils.DECIMAL4
@@ -252,18 +323,24 @@ class VariantUtils:
             or type_info == VariantUtils.DECIMAL16
         ):
             return decimal.Decimal
+        elif type_info == VariantUtils.BINARY:
+            return bytes
+        elif type_info == VariantUtils.DATE:
+            return datetime.date
+        elif type_info == VariantUtils.TIMESTAMP or type_info == VariantUtils.TIMESTAMP_NTZ:
+            return datetime.datetime
         elif type_info == VariantUtils.LONG_STR:
             return str
         raise PySparkValueError(error_class="MALFORMED_VARIANT")
 
     @classmethod
-    def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+    def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
         variant_type = cls._get_type(value, pos)
         if variant_type == dict:
 
             def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> str:
                 key_value_list = [
-                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos)
+                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos, zone_id)
                     for (key, value_pos) in key_value_pos_list
                 ]
                 return "{" + ",".join(key_value_list) + "}"
@@ -273,19 +350,25 @@ class VariantUtils:
 
             def handle_array(value_pos_list: List[int]) -> str:
                 value_list = [
-                    cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list
+                    cls._to_json(value, metadata, value_pos, zone_id)
+                    for value_pos in value_pos_list
                 ]
                 return "[" + ",".join(value_list) + "]"
 
             return cls._handle_array(value, pos, handle_array)
         else:
-            value = cls._get_scalar(variant_type, value, metadata, pos)
+            value = cls._get_scalar(variant_type, value, metadata, pos, zone_id)
             if value is None:
                 return "null"
            if type(value) == bool:
                 return "true" if value else "false"
             if type(value) == str:
                 return json.dumps(value)
+            if type(value) == bytes:
+                # decode() converts the base64-encoded bytes to a string
+                return '"' + base64.b64encode(value).decode("utf-8") + '"'
+            if type(value) == datetime.date or type(value) == datetime.datetime:
+                return '"' + str(value) + '"'
             return str(value)
 
     @classmethod
@@ -311,10 +394,12 @@ class VariantUtils:
 
             return cls._handle_array(value, pos, handle_array)
         else:
-            return cls._get_scalar(variant_type, value, metadata, pos)
+            return cls._get_scalar(variant_type, value, metadata, pos, zone_id="UTC")
 
     @classmethod
-    def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any:
+    def _get_scalar(
+        cls, variant_type: Any, value: bytes, metadata: bytes, pos: int, zone_id: str
+    ) -> Any:
         if isinstance(None, variant_type):
             return None
         elif variant_type == bool:
@@ -327,6 +412,12 @@ class VariantUtils:
             return cls._get_double(value, pos)
         elif variant_type == decimal.Decimal:
             return cls._get_decimal(value, pos)
+        elif variant_type == bytes:
+            return cls._get_binary(value, pos)
+        elif variant_type == datetime.date:
+            return cls._get_date(value, pos)
+        elif variant_type == datetime.datetime:
+            return cls._get_timestamp(value, pos, zone_id)
         else:
             raise PySparkValueError(error_class="MALFORMED_VARIANT")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e4e663d15167..5f43cc106e67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -822,6 +822,7 @@ object FunctionRegistry {
     // Variant
     expression[ParseJson]("parse_json"),
+    expression[TryParseJson]("try_parse_json"),
     expression[IsVariantNull]("is_variant_null"),
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 1a4a0271c54b..66c2f736f235 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -48,7 +48,6 @@ trait ExpectsInputTypes extends Expression {
 }
 
 object ExpectsInputTypes extends QueryErrorsBase {
-
   def checkInputDataTypes(
       inputs: Seq[Expression],
       inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 6c4a8f90e3b5..07f08aa7e70e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -75,6 +75,36 @@ case class ParseJson(child: Expression)
     copy(child = newChild)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(jsonStr) - Parse a JSON string as a Variant value. Returns null when the string is not a valid JSON value.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('{"a":1,"b":0.8}');
+       {"a":1,"b":0.8}
+  """,
+  since = "4.0.0",
+  group = "variant_funcs"
+)
+// scalastyle:on line.size.limit
+case class TryParseJson(expr: Expression, replacement: Expression)
+  extends RuntimeReplaceable with InheritAnalysisRules {
+  def this(child: Expression) = this(child, TryEval(ParseJson(child)))
+
+  override def parameters: Seq[Expression] = Seq(expr)
+
+  override def dataType: DataType = VariantType
+
+  override def prettyName: String = "try_parse_json"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(replacement = newChild)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    ExpectsInputTypes.checkInputDataTypes(Seq(expr), Seq(StringType))
+  }
+}
+
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Check if a variant value is a variant null. Returns true if and only if the input is a variant null and false otherwise (including in the case of SQL NULL).",
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index ae9e68c4cbb1..a2fa30b7f364 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -440,6 +440,7 @@
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct<schema_of_variant_agg(parse_json(j)):string> |
+| org.apache.spark.sql.catalyst.expressions.variant.TryParseJson | try_parse_json | SELECT try_parse_json('{"a":1,"b":0.8}') | struct<try_parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index d53b49f7ab5a..96e85dc58b40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -88,6 +88,46 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
   }
 
+  test("try_parse_json/to_json round-trip") {
+    def check(input: String, output: String = "INPUT IS OUTPUT"): Unit = {
+      val df = Seq(input).toDF("v")
+      val variantDF = df.selectExpr("to_json(try_parse_json(v)) as v").select(Column("v"))
+      val expected = if (output != "INPUT IS OUTPUT") output else input
+      checkAnswer(variantDF, Seq(Row(expected)))
+    }
+
+    check("null")
+    check("true")
+    check("false")
+    check("-1")
+    check("1.0E10")
+    check("\"\"")
+    check("\"" + ("a" * 63) + "\"")
+    check("\"" + ("b" * 64) + "\"")
+    // scalastyle:off nonascii
+    check("\"" + ("你好,世界" * 20) + "\"")
+    // scalastyle:on nonascii
+    check("[]")
+    check("{}")
+    // scalastyle:off nonascii
+    check(
+      "[null, true, false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+      "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+    )
+    // scalastyle:on nonascii
+    check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+    // Places where parse_json should fail and therefore try_parse_json should return null
+    check("{1:2}", null)
+    check("{\"a\":1", null)
+    check("{\"a\":[a,b,c]}", null)
+  }
+
+  test("try_parse_json with invalid input type") {
+    // This test is required because the type checking logic in try_parse_json is custom.
+    val exception = intercept[Exception](spark.sql("select try_parse_json(1)"))
+    assert(exception != null)
+  }
+
   test("to_json with nested variant") {
     val df = Seq(1).toDF("v")
     val variantDF1 = df.select(
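
For readers who want to try the change, here is a minimal sketch of the new behavior, not part of the commit itself. It assumes a Spark build that includes this patch; the session setup and input literals are illustrative only and mirror the tests above:

    import datetime

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Cast some of the newly supported scalar types to variant, and exercise
    # try_parse_json on malformed input (mirroring VariantEndToEndSuite).
    row = spark.sql(
        "select cast(Date('2021-01-01') as variant) as d,"
        " cast(cast('1940-01-01 12:35:13.123+7:30' as timestamp) as variant) as t,"
        " try_parse_json('{\"a\":1') as bad"
    ).collect()[0]

    # Dates come back as datetime.date, timestamps as timezone-aware datetimes.
    assert row["d"].toPython() == datetime.date(2021, 1, 1)

    # toJson renders timestamp fields in the requested zone (UTC by default).
    print(row["t"].toJson("America/Los_Angeles"))  # "1939-12-31 21:05:13.123000-08:00"

    # try_parse_json yields SQL NULL (Python None) instead of raising on bad JSON.
    assert row["bad"] is None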