This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f6d5ad3ec75b [SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark f6d5ad3ec75b is described below commit f6d5ad3ec75be63472c6b21dda959972f5360ef2 Author: Gene Pang <gene.p...@databricks.com> AuthorDate: Thu Apr 11 09:16:10 2024 +0900 [SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark ### What changes were proposed in this pull request? Add a `VariantVal` implementation for PySpark. It includes convenience methods to convert the Variant to a string, or to a Python object. ### Why are the changes needed? Allows users to work with Variant data more conveniently. ### Does this PR introduce _any_ user-facing change? This is new PySpark functionality to allow users to work with Variant data. ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #45826 from gene-db/variant-pyspark. Lead-authored-by: Gene Pang <gene.p...@databricks.com> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../source/reference/pyspark.sql/core_classes.rst | 1 + python/docs/source/reference/pyspark.sql/index.rst | 1 + .../pyspark.sql/{index.rst => variant_val.rst} | 32 +- python/pyspark/sql/__init__.py | 3 +- python/pyspark/sql/connect/conversion.py | 40 +++ python/pyspark/sql/pandas/types.py | 22 ++ python/pyspark/sql/tests/test_types.py | 64 ++++ python/pyspark/sql/types.py | 59 +++- python/pyspark/sql/variant_utils.py | 388 +++++++++++++++++++++ .../org/apache/spark/sql/util/ArrowUtils.scala | 10 + .../spark/sql/execution/arrow/ArrowWriter.scala | 27 +- 11 files changed, 614 insertions(+), 33 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/core_classes.rst index 65096da21de5..d3dbbc129cb7 100644 --- a/python/docs/source/reference/pyspark.sql/core_classes.rst +++ b/python/docs/source/reference/pyspark.sql/core_classes.rst @@ -49,3 +49,4 @@ Core Classes datasource.DataSourceRegistration datasource.InputPartition datasource.WriterCommitMessage + VariantVal diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst index 9322a91fba25..93901ab7ce12 100644 --- a/python/docs/source/reference/pyspark.sql/index.rst +++ b/python/docs/source/reference/pyspark.sql/index.rst @@ -41,5 +41,6 @@ This page gives an overview of all public Spark SQL API. observation udf udtf + variant_val protobuf datasource diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/variant_val.rst similarity index 70% copy from python/docs/source/reference/pyspark.sql/index.rst copy to python/docs/source/reference/pyspark.sql/variant_val.rst index 9322a91fba25..a7f592c18e3a 100644 --- a/python/docs/source/reference/pyspark.sql/index.rst +++ b/python/docs/source/reference/pyspark.sql/variant_val.rst @@ -16,30 +16,12 @@ under the License. -========= -Spark SQL -========= +========== +VariantVal +========== +.. currentmodule:: pyspark.sql -This page gives an overview of all public Spark SQL API. +.. autosummary:: + :toctree: api/ -.. 
toctree:: - :maxdepth: 2 - - core_classes - spark_session - configuration - io - dataframe - column - data_types - row - functions - window - grouping - catalog - avro - observation - udf - udtf - protobuf - datasource + VariantVal.toPython diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index dd82b037a6b9..bc046da81d27 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -39,7 +39,7 @@ Important classes of Spark SQL and DataFrames: - :class:`pyspark.sql.Window` For working with window functions. """ -from pyspark.sql.types import Row +from pyspark.sql.types import Row, VariantVal from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column @@ -67,6 +67,7 @@ __all__ = [ "Row", "DataFrameNaFunctions", "DataFrameStatFunctions", + "VariantVal", "Window", "WindowSpec", "DataFrameReader", diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index c86ee9c75fec..9b1007c41f9c 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -40,6 +40,8 @@ from pyspark.sql.types import ( DecimalType, StringType, UserDefinedType, + VariantType, + VariantVal, ) from pyspark.storagelevel import StorageLevel @@ -95,6 +97,8 @@ class LocalDataToArrowConversion: return True elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, VariantType): + return True else: return False @@ -290,6 +294,24 @@ class LocalDataToArrowConversion: return convert_udt + elif isinstance(dataType, VariantType): + + def convert_variant(value: Any) -> Any: + if value is None: + if not nullable: + raise PySparkValueError(f"input for {dataType} must not be None") + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["value", "metadata"]) + and all(isinstance(value[key], bytes) for key in ["value", "metadata"]) + ): + return VariantVal(value["value"], value["metadata"]) + else: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + return convert_variant + elif not nullable: def convert_other(value: Any) -> Any: @@ -381,6 +403,8 @@ class ArrowTableToRowsConversion: return True elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, VariantType): + return True else: return False @@ -488,6 +512,22 @@ class ArrowTableToRowsConversion: return convert_udt + elif isinstance(dataType, VariantType): + + def convert_variant(value: Any) -> Any: + if value is None: + return None + elif ( + isinstance(value, dict) + and all(key in value for key in ["value", "metadata"]) + and all(isinstance(value[key], bytes) for key in ["value", "metadata"]) + ): + return VariantVal(value["value"], value["metadata"]) + else: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + return convert_variant + else: return lambda value: value diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 3b48f8d8c319..559512bd00c1 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -47,6 +47,8 @@ from pyspark.sql.types import ( NullType, DataType, UserDefinedType, + VariantType, + VariantVal, _create_row, ) from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError @@ -108,6 +110,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": arrow_type = pa.null() elif isinstance(dt, UserDefinedType): arrow_type = 
to_arrow_type(dt.sqlType()) + elif type(dt) == VariantType: + fields = [ + pa.field("value", pa.binary(), nullable=False), + pa.field("metadata", pa.binary(), nullable=False), + ] + arrow_type = pa.struct(fields) else: raise PySparkTypeError( error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", @@ -763,6 +771,20 @@ def _create_converter_to_pandas( return convert_udt + elif isinstance(dt, VariantType): + + def convert_variant(value: Any) -> Any: + if ( + isinstance(value, dict) + and all(key in value for key in ["value", "metadata"]) + and all(isinstance(value[key], bytes) for key in ["value", "metadata"]) + ): + return VariantVal(value["value"], value["metadata"]) + else: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + return convert_variant + else: return None diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index bb854641906a..af13adbc21bb 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -58,6 +58,7 @@ from pyspark.sql.types import ( BooleanType, NullType, VariantType, + VariantVal, ) from pyspark.sql.types import ( _array_signed_int_typecode_ctype_mappings, @@ -1406,6 +1407,69 @@ class TypesTestsMixin: schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType()) + def test_variant_type(self): + from decimal import Decimal + + self.assertEqual(VariantType().simpleString(), "variant") + + # Holds a tuple of (key, json string value, python value) + expected_values = [ + ("str", '"%s"' % ("0123456789" * 10), "0123456789" * 10), + ("short_str", '"abc"', "abc"), + ("null", "null", None), + ("true", "true", True), + ("false", "false", False), + ("int1", "1", 1), + ("-int1", "-5", -5), + ("int2", "257", 257), + ("-int2", "-124", -124), + ("int4", "65793", 65793), + ("-int4", "-69633", -69633), + ("int8", "4295033089", 4295033089), + ("-int8", "-4294967297", -4294967297), + ("float4", "1.23456789e-30", 1.23456789e-30), + ("-float4", "-4.56789e+29", -4.56789e29), + ("dec4", "123.456", Decimal("123.456")), + ("-dec4", "-321.654", Decimal("-321.654")), + ("dec8", "429.4967297", Decimal("429.4967297")), + ("-dec8", "-5.678373902", Decimal("-5.678373902")), + ("dec16", "467440737095.51617", Decimal("467440737095.51617")), + ("-dec16", "-67.849438003827263", Decimal("-67.849438003827263")), + ("arr", '[1.1,"2",[3],{"4":5}]', [Decimal("1.1"), "2", [3], {"4": 5}]), + ("obj", '{"a":["123",{"b":2}],"c":3}', {"a": ["123", {"b": 2}], "c": 3}), + ] + json_str = "{%s}" % ",".join(['"%s": %s' % (t[0], t[1]) for t in expected_values]) + + df = self.spark.createDataFrame([({"json": json_str})]) + row = df.select( + F.parse_json(df.json).alias("v"), + F.array([F.parse_json(F.lit('{"a": 1}'))]).alias("a"), + F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"), + F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"), + ).collect()[0] + variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]] + for v in variants: + self.assertEqual(type(v), VariantVal) + + # check str + as_string = str(variants[0]) + for key, expected, _ in expected_values: + self.assertTrue('"%s":%s' % (key, expected) in as_string) + self.assertEqual(str(variants[1]), '{"a":1}') + self.assertEqual(str(variants[2]), '{"b":"2"}') + self.assertEqual(str(variants[3]), '{"c":true}') + + # check toPython + as_python = variants[0].toPython() + for key, _, obj in expected_values: + self.assertEqual(as_python[key], obj) + 
self.assertEqual(variants[1].toPython(), {"a": 1}) + self.assertEqual(variants[2].toPython(), {"b": "2"}) + self.assertEqual(variants[3].toPython(), {"c": True}) + + # check repr + self.assertEqual(str(variants[0]), str(eval(repr(variants[0])))) + def test_from_ddl(self): self.assertEqual(DataType.fromDDL("long"), LongType()) self.assertEqual( diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9b1bab4c23fa..3546fd822814 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -48,6 +48,7 @@ from typing import ( from pyspark.util import is_remote_only from pyspark.serializers import CloudPickleSerializer from pyspark.sql.utils import has_numpy, get_active_spark_context +from pyspark.sql.variant_utils import VariantUtils from pyspark.errors import ( PySparkNotImplementedError, PySparkTypeError, @@ -95,6 +96,7 @@ __all__ = [ "StructField", "StructType", "VariantType", + "VariantVal", ] @@ -1341,7 +1343,13 @@ class VariantType(AtomicType): .. versionadded:: 4.0.0 """ - pass + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Dict) -> Optional["VariantVal"]: + if obj is None or not all(key in obj for key in ["value", "metadata"]): + return None + return VariantVal(obj["value"], obj["metadata"]) class UserDefinedType(DataType): @@ -1465,6 +1473,55 @@ class UserDefinedType(DataType): return type(self) == type(other) +class VariantVal: + """ + A class to represent a Variant value in Python. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + value : bytes + The bytes representing the value component of the Variant. + metadata : bytes + The bytes representing the metadata component of the Variant. + + Methods + ------- + toPython() + Convert the VariantVal to a Python data structure. + + Examples + -------- + >>> from pyspark.sql.functions import * + >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) + >>> v = df.select(parse_json(df.json).alias("var")).collect()[0].var + >>> v.toPython() + {'a': 1} + """ + + def __init__(self, value: bytes, metadata: bytes): + self.value = value + self.metadata = metadata + + def __str__(self) -> str: + return VariantUtils.to_json(self.value, self.metadata) + + def __repr__(self) -> str: + return "VariantVal(%r, %r)" % (self.value, self.metadata) + + def toPython(self) -> Any: + """ + Convert the VariantVal to a Python data structure. + + Returns + ------- + Any + A Python object that represents the Variant. + """ + return VariantUtils.to_python(self.value, self.metadata) + + _atomic_types: List[Type[DataType]] = [ StringType, CharType, diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py new file mode 100644 index 000000000000..9ca70365316d --- /dev/null +++ b/python/pyspark/sql/variant_utils.py @@ -0,0 +1,388 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import decimal +import json +import struct +from array import array +from typing import Any, Callable, Dict, List, Tuple +from pyspark.errors import PySparkValueError + + +class VariantUtils: + """ + A utility class for VariantVal. + + Adapted from library at: org.apache.spark.types.variant.VariantUtil + """ + + BASIC_TYPE_BITS = 2 + BASIC_TYPE_MASK = 0x3 + TYPE_INFO_MASK = 0x3F + # The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`. + MAX_SHORT_STR_SIZE = 0x3F + + # Below is all possible basic type values. + # Primitive value. The type info value must be one of the values in the below section. + PRIMITIVE = 0 + # Short string value. The type info value is the string size, which must be in `[0, + # MAX_SHORT_STR_SIZE]`. + # The string content bytes directly follow the header byte. + SHORT_STR = 1 + # Object value. The content contains a size, a list of field ids, a list of field offsets, and + # the actual field data. The length of the id list is `size`, while the length of the offset + # list is `size + 1`, where the last offset represent the total size of the field data. The + # fields in an object must be sorted by the field name in alphabetical order. Duplicate field + # names in one object are not allowed. + # We use 5 bits in the type info to specify the integer type of the object header: it should + # be 0_b4_b3b2_b1b0 (MSB is 0), where: + # - b4 specifies the type of size. When it is 0/1, `size` is a little-endian 1/4-byte + # unsigned integer. + # - b3b2/b1b0 specifies the integer type of id and offset. When the 2 bits are 0/1/2, the + # list contains 1/2/3-byte little-endian unsigned integers. + OBJECT = 2 + # Array value. The content contains a size, a list of field offsets, and the actual element + # data. It is similar to an object without the id list. The length of the offset list + # is `size + 1`, where the last offset represent the total size of the element data. + # Its type info should be: 000_b2_b1b0: + # - b2 specifies the type of size. + # - b1b0 specifies the integer type of offset. + ARRAY = 3 + + # Below is all possible type info values for `PRIMITIVE`. + # JSON Null value. Empty content. + NULL = 0 + # True value. Empty content. + TRUE = 1 + # False value. Empty content. + FALSE = 2 + # 1-byte little-endian signed integer. + INT1 = 3 + # 2-byte little-endian signed integer. + INT2 = 4 + # 4-byte little-endian signed integer. + INT4 = 5 + # 4-byte little-endian signed integer. + INT8 = 6 + # 8-byte IEEE double. + DOUBLE = 7 + # 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer. + DECIMAL4 = 8 + # 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer. + DECIMAL8 = 9 + # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer. + DECIMAL16 = 10 + # Long string value. The content is (4-byte little-endian unsigned integer representing the + # string size) + (size bytes of string content). + LONG_STR = 16 + + U32_SIZE = 4 + + @classmethod + def to_json(cls, value: bytes, metadata: bytes) -> str: + """ + Convert the VariantVal to a JSON string. + :return: JSON string + """ + return cls._to_json(value, metadata, 0) + + @classmethod + def to_python(cls, value: bytes, metadata: bytes) -> str: + """ + Convert the VariantVal to a nested Python object of Python data types. 
+ :return: Python representation of the Variant nested structure + """ + return cls._to_python(value, metadata, 0) + + @classmethod + def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int: + cls._check_index(pos, len(data)) + cls._check_index(pos + num_bytes - 1, len(data)) + return int.from_bytes(data[pos : pos + num_bytes], byteorder="little", signed=signed) + + @classmethod + def _check_index(cls, pos: int, length: int) -> None: + if pos < 0 or pos >= length: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + @classmethod + def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]: + """ + Returns the (basic_type, type_info) pair from the given position in the value. + """ + basic_type = value[pos] & VariantUtils.BASIC_TYPE_MASK + type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK + return (basic_type, type_info) + + @classmethod + def _get_metadata_key(cls, metadata: bytes, id: int) -> str: + """ + Returns the key string from the dictionary in the metadata, corresponding to `id`. + """ + cls._check_index(0, len(metadata)) + offset_size = ((metadata[0] >> 6) & 0x3) + 1 + dict_size = cls._read_long(metadata, 1, offset_size, signed=False) + if id >= dict_size: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + string_start = 1 + (dict_size + 2) * offset_size + offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False) + next_offset = cls._read_long( + metadata, 1 + (id + 2) * offset_size, offset_size, signed=False + ) + if offset > next_offset: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + cls._check_index(string_start + next_offset - 1, len(metadata)) + return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8") + + @classmethod + def _get_boolean(cls, value: bytes, pos: int) -> bool: + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.PRIMITIVE or ( + type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE + ): + raise PySparkValueError(error_class="MALFORMED_VARIANT") + return type_info == VariantUtils.TRUE + + @classmethod + def _get_long(cls, value: bytes, pos: int) -> int: + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.PRIMITIVE: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + if type_info == VariantUtils.INT1: + return cls._read_long(value, pos + 1, 1, signed=True) + elif type_info == VariantUtils.INT2: + return cls._read_long(value, pos + 1, 2, signed=True) + elif type_info == VariantUtils.INT4: + return cls._read_long(value, pos + 1, 4, signed=True) + elif type_info == VariantUtils.INT8: + return cls._read_long(value, pos + 1, 8, signed=True) + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + @classmethod + def _get_string(cls, value: bytes, pos: int) -> str: + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type == VariantUtils.SHORT_STR or ( + basic_type == VariantUtils.PRIMITIVE and type_info == VariantUtils.LONG_STR + ): + start = 0 + length = 0 + if basic_type == VariantUtils.SHORT_STR: + start = pos + 1 + length = type_info + else: + start = pos + 1 + VariantUtils.U32_SIZE + length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False) + cls._check_index(start + length - 1, len(value)) + return value[start : start + length].decode("utf-8") + raise 
PySparkValueError(error_class="MALFORMED_VARIANT") + + @classmethod + def _get_double(cls, value: bytes, pos: int) -> float: + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.DOUBLE: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + return struct.unpack("d", value[pos + 1 : pos + 9])[0] + + @classmethod + def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal: + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.PRIMITIVE: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + scale = value[pos + 1] + unscaled = 0 + if type_info == VariantUtils.DECIMAL4: + unscaled = cls._read_long(value, pos + 2, 4, signed=True) + elif type_info == VariantUtils.DECIMAL8: + unscaled = cls._read_long(value, pos + 2, 8, signed=True) + elif type_info == VariantUtils.DECIMAL16: + cls._check_index(pos + 17, len(value)) + unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True) + else: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale)) + + @classmethod + def _get_type(cls, value: bytes, pos: int) -> Any: + """ + Returns the Python type of the Variant at the given position. + """ + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type == VariantUtils.SHORT_STR: + return str + elif basic_type == VariantUtils.OBJECT: + return dict + elif basic_type == VariantUtils.ARRAY: + return array + elif type_info == VariantUtils.NULL: + return type(None) + elif type_info == VariantUtils.TRUE or type_info == VariantUtils.FALSE: + return bool + elif ( + type_info == VariantUtils.INT1 + or type_info == VariantUtils.INT2 + or type_info == VariantUtils.INT4 + or type_info == VariantUtils.INT8 + ): + return int + elif type_info == VariantUtils.DOUBLE: + return float + elif ( + type_info == VariantUtils.DECIMAL4 + or type_info == VariantUtils.DECIMAL8 + or type_info == VariantUtils.DECIMAL16 + ): + return decimal.Decimal + elif type_info == VariantUtils.LONG_STR: + return str + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + @classmethod + def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any: + variant_type = cls._get_type(value, pos) + if variant_type == dict: + + def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> str: + key_value_list = [ + json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos) + for (key, value_pos) in key_value_pos_list + ] + return "{" + ",".join(key_value_list) + "}" + + return cls._handle_object(value, metadata, pos, handle_object) + elif variant_type == array: + + def handle_array(value_pos_list: list[int]) -> str: + value_list = [ + cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list + ] + return "[" + ",".join(value_list) + "]" + + return cls._handle_array(value, pos, handle_array) + else: + value = cls._get_scalar(variant_type, value, metadata, pos) + if value is None: + return "null" + if type(value) == bool: + return "true" if value else "false" + if type(value) == str: + return json.dumps(value) + return str(value) + + @classmethod + def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any: + variant_type = cls._get_type(value, pos) + if variant_type == dict: + + def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> Dict[str, Any]: + 
key_value_list = [ + (key, cls._to_python(value, metadata, value_pos)) + for (key, value_pos) in key_value_pos_list + ] + return dict(key_value_list) + + return cls._handle_object(value, metadata, pos, handle_object) + elif variant_type == array: + + def handle_array(value_pos_list: list[int]) -> List[Any]: + value_list = [ + cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list + ] + return value_list + + return cls._handle_array(value, pos, handle_array) + else: + return cls._get_scalar(variant_type, value, metadata, pos) + + @classmethod + def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any: + if isinstance(None, variant_type): + return None + elif variant_type == bool: + return cls._get_boolean(value, pos) + elif variant_type == int: + return cls._get_long(value, pos) + elif variant_type == str: + return cls._get_string(value, pos) + elif variant_type == float: + return cls._get_double(value, pos) + elif variant_type == decimal.Decimal: + return cls._get_decimal(value, pos) + else: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + + @classmethod + def _handle_object( + cls, value: bytes, metadata: bytes, pos: int, func: Callable[[list[Tuple[str, int]]], Any] + ) -> Any: + """ + Parses the variant object at position `pos`. + Calls `func` with a list of (key, value position) pairs of the object. + """ + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.OBJECT: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + large_size = ((type_info >> 4) & 0x1) != 0 + size_bytes = VariantUtils.U32_SIZE if large_size else 1 + num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False) + id_size = ((type_info >> 2) & 0x3) + 1 + offset_size = ((type_info) & 0x3) + 1 + id_start = pos + 1 + size_bytes + offset_start = id_start + num_fields * id_size + data_start = offset_start + (num_fields + 1) * offset_size + + key_value_pos_list = [] + for i in range(num_fields): + id = cls._read_long(value, id_start + id_size * i, id_size, signed=False) + offset = cls._read_long( + value, offset_start + offset_size * i, offset_size, signed=False + ) + value_pos = data_start + offset + key_value_pos_list.append((cls._get_metadata_key(metadata, id), value_pos)) + return func(key_value_pos_list) + + @classmethod + def _handle_array(cls, value: bytes, pos: int, func: Callable[[list[int]], Any]) -> Any: + """ + Parses the variant array at position `pos`. + Calls `func` with a list of element positions of the array. 
+ """ + cls._check_index(pos, len(value)) + basic_type, type_info = cls._get_type_info(value, pos) + if basic_type != VariantUtils.ARRAY: + raise PySparkValueError(error_class="MALFORMED_VARIANT") + large_size = ((type_info >> 2) & 0x1) != 0 + size_bytes = VariantUtils.U32_SIZE if large_size else 1 + num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False) + offset_size = (type_info & 0x3) + 1 + offset_start = pos + 1 + size_bytes + data_start = offset_start + (num_fields + 1) * offset_size + + value_pos_list = [] + for i in range(num_fields): + offset = cls._read_long( + value, offset_start + offset_size * i, offset_size, signed=False + ) + element_pos = data_start + offset + value_pos_list.append(element_pos) + return func(value_pos_list) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 92a4c687362d..d9bd3b0e612b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -125,6 +125,12 @@ private[sql] object ArrowUtils { largeVarTypes)).asJava) case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case _: VariantType => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null, + Map("variant" -> "true").asJava) + new Field(name, fieldType, + Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes), + toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava) case dataType => val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) @@ -143,6 +149,10 @@ private[sql] object ArrowUtils { val elementField = field.getChildren().get(0) val elementType = fromArrowField(elementField) ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true" + && field.getChildren.asScala.map(_.getName).asJava + .containsAll(Seq("value", "metadata").asJava) => + VariantType case ArrowType.Struct.INSTANCE => val fields = field.getChildren().asScala.map { child => val dt = fromArrowField(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 6680a5320fe3..ca7703bef48b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -84,6 +84,11 @@ object ArrowWriter { case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => new IntervalMonthDayNanoWriter(vector) + case (VariantType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) case (dt, _) => throw ExecutionErrors.unsupportedDataTypeError(dt) } @@ -368,6 +373,8 @@ private[arrow] class StructWriter( val valueVector: StructVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { + lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true" + override def setNull(): Unit = { var i = 0 while (i < children.length) { @@ -379,12 +386,20 @@ private[arrow] class StructWriter( } override def setValue(input: 
SpecializedGetters, ordinal: Int): Unit = {
-    val struct = input.getStruct(ordinal, children.length)
-    var i = 0
-    valueVector.setIndexDefined(count)
-    while (i < struct.numFields) {
-      children(i).write(struct, i)
-      i += 1
+    if (isVariant) {
+      valueVector.setIndexDefined(count)
+      val v = input.getVariant(ordinal)
+      val row = InternalRow(v.getValue, v.getMetadata)
+      children(0).write(row, 0)
+      children(1).write(row, 1)
+    } else {
+      val struct = input.getStruct(ordinal, children.length)
+      var i = 0
+      valueVector.setIndexDefined(count)
+      while (i < struct.numFields) {
+        children(i).write(struct, i)
+        i += 1
+      }
     }
   }
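For reference on the Arrow changes above: the patch carries a VariantType column through Arrow as a struct with two non-nullable binary children, "value" and "metadata" (see to_arrow_type in python/pyspark/sql/pandas/types.py and ArrowUtils.toArrowField, which also tags the field with "variant" -> "true" metadata so fromArrowField can map it back to VariantType). A minimal pyarrow sketch of that layout; the variable name is illustrative and not part of the patch:

import pyarrow as pa

# Arrow layout used for a Variant column: a struct of two non-nullable
# binary fields, mirroring to_arrow_type() for VariantType in this patch.
variant_arrow_type = pa.struct([
    pa.field("value", pa.binary(), nullable=False),
    pa.field("metadata", pa.binary(), nullable=False),
])

# Each cell crosses the Arrow boundary as a dict of raw bytes; the converters
# added in conversion.py and pandas/types.py check that both keys are present
# and hold bytes before wrapping them in a VariantVal.
print(variant_arrow_type)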
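On the Python side, the user-facing surface is small: parse_json produces variant values, and collected rows hold VariantVal objects exposing __str__, __repr__ and toPython(). A short usage sketch adapted from the VariantVal docstring and test_variant_type above; it assumes an active SparkSession named `spark` on a build that includes this patch:

from pyspark.sql import VariantVal
from pyspark.sql import functions as F

df = spark.createDataFrame([{"json": '{"a": 1, "b": [2, 3], "c": "str"}'}])
row = df.select(F.parse_json(df.json).alias("v")).collect()[0]

v = row["v"]
assert isinstance(v, VariantVal)
print(str(v))        # JSON text of the variant, e.g. {"a":1,"b":[2,3],"c":"str"}
print(v.toPython())  # {'a': 1, 'b': [2, 3], 'c': 'str'}
print(repr(v))       # VariantVal(b'...', b'...') -- the raw value/metadata bytes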