This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new d67752a8f3d7 [SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling
d67752a8f3d7 is described below

commit d67752a8f3d7c5bda1f56c940b5112c1d5d82d07
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Tue May 7 08:10:41 2024 +0800

    [SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling
    
    ### What changes were proposed in this pull request?
    
    There are two issues with the current variant decimal handling:
    1. The precision and scale of the `BigDecimal` returned by `getDecimal` are not checked. Per the variant spec, they must be within the corresponding limits for DECIMAL4/8/16. An out-of-range decimal can lead to failures in downstream Spark operations.
    2. The current `schema_of_variant` implementation doesn't correctly handle the case where the precision is smaller than the scale. Spark's `DecimalType` requires `precision >= scale`.
    
    The Python side requires a similar fix for issue 1. While making that fix, I found that the Python error reporting was not correctly implemented (and never tested), so I fixed it as well.
    
    ### Why are the changes needed?
    
    They are bug fixes and are required to process decimals correctly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46338 from chenhao-db/fix_variant_decimal.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
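As a concrete illustration of issue 2 (a minimal standalone sketch, not code from this patch; `decimal_type_of` is a hypothetical helper): a decimal such as 0.01 has one significant digit (precision 1) at scale 2, so the pre-fix `DecimalType(precision, scale)` call would be invalid. The fix widens the precision to `max(precision, scale)`:

    from decimal import Decimal

    def decimal_type_of(d: Decimal) -> str:
        # Mirrors the fixed Scala line `DecimalType(d.precision().max(d.scale()), d.scale())`,
        # assuming a non-negative scale.
        _sign, digits, exponent = d.as_tuple()
        scale = -exponent                      # digits after the decimal point
        precision = max(len(digits), scale)    # widen precision so precision >= scale holds
        return f"DECIMAL({precision},{scale})"

    print(decimal_type_of(Decimal("0.01")))   # DECIMAL(2,2), not the invalid DECIMAL(1,2)
    print(decimal_type_of(Decimal("12.34")))  # DECIMAL(4,2), unchanged by the fix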
---
 .../apache/spark/types/variant/VariantUtil.java    | 14 ++++-
 python/pyspark/errors/error-conditions.json        |  5 ++
 python/pyspark/sql/tests/test_types.py             | 13 +++++
 python/pyspark/sql/variant_utils.py                | 59 ++++++++++++++--------
 .../expressions/variant/variantExpressions.scala   |  3 +-
 .../variant/VariantExpressionSuite.scala           |  3 ++
 .../apache/spark/sql/VariantEndToEndSuite.scala    |  1 +
 7 files changed, 76 insertions(+), 22 deletions(-)

diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index e4e9cc8b4cfa..84e3a45e4b0e 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -392,6 +392,13 @@ public class VariantUtil {
     return Double.longBitsToDouble(readLong(value, pos + 1, 8));
   }
 
+  // Check whether the precision and scale of the decimal are within the limit.
+  private static void checkDecimal(BigDecimal d, int maxPrecision) {
+    if (d.precision() > maxPrecision || d.scale() > maxPrecision) {
+      throw malformedVariant();
+    }
+  }
+
   // Get a decimal value from variant value `value[pos...]`.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static BigDecimal getDecimal(byte[] value, int pos) {
@@ -399,14 +406,18 @@ public class VariantUtil {
     int basicType = value[pos] & BASIC_TYPE_MASK;
     int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
     if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
-    int scale = value[pos + 1];
+    // Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be
+    // greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`.
+    int scale = value[pos + 1] & 0xFF;
     BigDecimal result;
     switch (typeInfo) {
       case DECIMAL4:
         result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
+        checkDecimal(result, MAX_DECIMAL4_PRECISION);
         break;
       case DECIMAL8:
         result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
+        checkDecimal(result, MAX_DECIMAL8_PRECISION);
         break;
       case DECIMAL16:
         checkIndex(pos + 17, value.length);
@@ -417,6 +428,7 @@ public class VariantUtil {
           bytes[i] = value[pos + 17 - i];
         }
         result = new BigDecimal(new BigInteger(bytes), scale);
+        checkDecimal(result, MAX_DECIMAL16_PRECISION);
         break;
       default:
         throw unexpectedType(Type.DECIMAL);
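To see why the unsigned reinterpretation above matters (an illustrative sketch, not part of the patch): Java bytes are signed, so without the `& 0xFF` mask a stored scale byte of 0xFF reads as -1 and would be passed to `BigDecimal` as a negative scale instead of failing. With the mask it reads as 255, which exceeds `MAX_DECIMAL16_PRECISION` (38), so `checkDecimal` rejects it:

    # Two's-complement reinterpretation of the scale byte, sketched in Python.
    raw = 0xFF
    as_signed_java_byte = raw - 256 if raw >= 0x80 else raw  # -1: the pre-fix reading
    as_unsigned = raw & 0xFF                                 # 255: the post-fix reading
    assert (as_signed_java_byte, as_unsigned) == (-1, 255)

Python's `bytes` indexing is already unsigned (`bytes([0xFF])[0] == 255`), which is why the `variant_utils.py` change below only needs the range check rather than a mask.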
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 7771791e41ca..906bf781e1bb 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -482,6 +482,11 @@
       "<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
     ]
   },
+  "MALFORMED_VARIANT" : {
+    "message" : [
+      "Variant binary is malformed. Please check the data source is valid."
+    ]
+  },
   "MASTER_URL_NOT_SET": {
     "message": [
       "A master URL must be set in your configuration."
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 7d45adb832c8..40eded6a4433 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1577,6 +1577,19 @@ class TypesTestsMixin:
         # check repr
         self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
 
+        metadata = bytes([1, 0, 0])
+        self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
+        self.assertEqual(str(VariantVal(bytes([32, 1, 2, 0, 0, 0]), metadata)), "0.2")
+        self.assertEqual(str(VariantVal(bytes([32, 2, 3, 0, 0, 0]), metadata)), "0.03")
+        self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
+        self.assertEqual(str(VariantVal(bytes([32, 0, 255, 201, 154, 59]), metadata)), "999999999")
+        self.assertRaises(
+            PySparkValueError, lambda: str(VariantVal(bytes([32, 0, 0, 202, 154, 59]), metadata))
+        )
+        self.assertRaises(
+            PySparkValueError, lambda: str(VariantVal(bytes([32, 10, 1, 0, 0, 0]), metadata))
+        )
+
     def test_from_ddl(self):
         self.assertEqual(DataType.fromDDL("long"), LongType())
         self.assertEqual(
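For reference, the DECIMAL4 test values above decode as follows (a hypothetical standalone walkthrough: the header byte 32 encodes a primitive DECIMAL4, i.e. type info 8 shifted left by 2; byte 1 is the scale; bytes 2-5 are the little-endian signed unscaled value):

    from decimal import Decimal

    def decode_decimal4(value: bytes) -> Decimal:
        # Sketch of the DECIMAL4 branch of `_get_decimal`, without the validity checks.
        scale = value[1]
        unscaled = int.from_bytes(value[2:6], byteorder="little", signed=True)
        return Decimal(unscaled) * Decimal(10) ** -scale

    print(decode_decimal4(bytes([32, 1, 2, 0, 0, 0])))         # 0.2
    print(decode_decimal4(bytes([32, 0, 255, 201, 154, 59])))  # 999999999

The two rejected inputs hit the new checks: [32, 0, 0, 202, 154, 59] encodes an unscaled value of 1000000000 (10 digits, exceeding MAX_DECIMAL4_PRECISION = 9), and [32, 10, 1, 0, 0, 0] encodes scale 10 > 9; both now raise MALFORMED_VARIANT.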
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 1ee139506b91..95084fc7d932 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -115,6 +115,13 @@ class VariantUtils:
     )
     EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)
 
+    MAX_DECIMAL4_PRECISION = 9
+    MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
+    MAX_DECIMAL8_PRECISION = 18
+    MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
+    MAX_DECIMAL16_PRECISION = 38
+    MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION
+
     @classmethod
     def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
         """
@@ -142,7 +149,7 @@ class VariantUtils:
     @classmethod
     def _check_index(cls, pos: int, length: int) -> None:
         if pos < 0 or pos >= length:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
@@ -162,14 +169,14 @@ class VariantUtils:
         offset_size = ((metadata[0] >> 6) & 0x3) + 1
         dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
         if id >= dict_size:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         string_start = 1 + (dict_size + 2) * offset_size
         offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
         next_offset = cls._read_long(
             metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
         )
         if offset > next_offset:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         cls._check_index(string_start + next_offset - 1, len(metadata))
         return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
 
@@ -180,7 +187,7 @@ class VariantUtils:
         if basic_type != VariantUtils.PRIMITIVE or (
             type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
         ):
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         return type_info == VariantUtils.TRUE
 
     @classmethod
@@ -188,7 +195,7 @@ class VariantUtils:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.INT1:
             return cls._read_long(value, pos + 1, 1, signed=True)
         elif type_info == VariantUtils.INT2:
@@ -197,25 +204,25 @@ class VariantUtils:
             return cls._read_long(value, pos + 1, 4, signed=True)
         elif type_info == VariantUtils.INT8:
             return cls._read_long(value, pos + 1, 8, signed=True)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_date(cls, value: bytes, pos: int) -> datetime.date:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.DATE:
             days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
             return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.TIMESTAMP_NTZ:
             microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
             return VariantUtils.EPOCH_NTZ + datetime.timedelta(
@@ -226,7 +233,7 @@ class VariantUtils:
             return (
                 VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
             ).astimezone(ZoneInfo(zone_id))
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
@@ -245,39 +252,51 @@ class VariantUtils:
             length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
             cls._check_index(start + length - 1, len(value))
             return value[start : start + length].decode("utf-8")
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_double(cls, value: bytes, pos: int) -> float:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.FLOAT:
             cls._check_index(pos + 4, len(value))
             return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
         elif type_info == VariantUtils.DOUBLE:
             cls._check_index(pos + 8, len(value))
             return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+
+    @classmethod
+    def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
+        # max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
+        # computation.
+        if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         scale = value[pos + 1]
         unscaled = 0
         if type_info == VariantUtils.DECIMAL4:
             unscaled = cls._read_long(value, pos + 2, 4, signed=True)
+            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
         elif type_info == VariantUtils.DECIMAL8:
             unscaled = cls._read_long(value, pos + 2, 8, signed=True)
+            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
         elif type_info == VariantUtils.DECIMAL16:
             cls._check_index(pos + 17, len(value))
             unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
+            cls._check_decimal(
+                unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
+            )
         else:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
 
     @classmethod
@@ -285,7 +304,7 @@ class VariantUtils:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         start = pos + 1 + VariantUtils.U32_SIZE
         length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
         cls._check_index(start + length - 1, len(value))
@@ -331,7 +350,7 @@ class VariantUtils:
             return datetime.datetime
         elif type_info == VariantUtils.LONG_STR:
             return str
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
@@ -419,7 +438,7 @@ class VariantUtils:
         elif variant_type == datetime.datetime:
             return cls._get_timestamp(value, pos, zone_id)
         else:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _handle_object(
@@ -432,7 +451,7 @@ class VariantUtils:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.OBJECT:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         large_size = ((type_info >> 4) & 0x1) != 0
         size_bytes = VariantUtils.U32_SIZE if large_size else 1
         num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -461,7 +480,7 @@ class VariantUtils:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.ARRAY:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         large_size = ((type_info >> 2) & 0x1) != 0
         size_bytes = VariantUtils.U32_SIZE if large_size else 1
         num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 43c561e10b6d..3dbc72415ff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -680,7 +680,8 @@ object SchemaOfVariant {
     case Type.DOUBLE => DoubleType
     case Type.DECIMAL =>
       val d = v.getDecimal
-      DecimalType(d.precision(), d.scale())
+      // Spark doesn't allow `DecimalType` to have `precision < scale`.
+      DecimalType(d.precision().max(d.scale()), d.scale())
     case Type.DATE => DateType
     case Type.TIMESTAMP => TimestampType
     case Type.TIMESTAMP_NTZ => TimestampNTZType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index d001f0ec051e..f4a6a144c221 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -58,6 +58,9 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
     // DECIMAL16 only has 15 byte content.
     check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
+    // 1e38 has a precision of 39. Even if it still fits into 16 bytes, it is not a valid decimal.
+    check(Array[Byte](primitiveHeader(DECIMAL16), 0) ++
+      BigDecimal(1e38).toBigInt.toByteArray.reverse, emptyMetadata)
     // Short string content too short.
     check(Array(shortStrHeader(2), 'x'), emptyMetadata)
     // Long string length too short (requires 4 bytes).
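A quick sanity check of the arithmetic behind the new DECIMAL16 test case (illustrative only; the test builds the little-endian payload with `toByteArray.reverse`): 10**38 fits in a 16-byte signed integer yet has 39 significant digits, one more than MAX_DECIMAL16_PRECISION allows:

    v = 10**38
    assert v < 2**127          # representable as a 16-byte signed little-endian integer
    assert len(str(v)) == 39   # 39 digits > MAX_DECIMAL16_PRECISION (38)
    payload = v.to_bytes(16, byteorder="little", signed=True)  # what the test encodes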
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index c8d267ff5ecc..3964bf3aedec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -157,6 +157,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     check("null", "VOID")
     check("1", "BIGINT")
     check("1.0", "DECIMAL(1,0)")
+    check("0.01", "DECIMAL(2,2)")
     check("1E0", "DOUBLE")
     check("true", "BOOLEAN")
     check("\"2000-01-01\"", "STRING")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org