This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4a1f2419e243 [SPARK-47451][SQL] Support to_json(variant). 4a1f2419e243 is described below commit 4a1f2419e243272064fde96529149316fe53bc10 Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Mon Mar 25 14:58:01 2024 +0800 [SPARK-47451][SQL] Support to_json(variant). ### What changes were proposed in this pull request? This PR adds the functionality to format a variant value as a JSON string. It is exposed in the `to_json` expression by allowing the variant type (or a nested type containing the variant type) as its input. ### How was this patch tested? Unit tests that validate the `to_json` result. The input includes both `parse_json` results and manually constructed bytes. Negative cases with malformed inputs are also covered. Some tests disabled in https://github.com/apache/spark/pull/45479 are re-enabled. Closes #45575 from chenhao-db/variant_to_json. Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- common/unsafe/pom.xml | 6 + .../org/apache/spark/unsafe/types/VariantVal.java | 4 +- .../src/main/resources/error/error-classes.json | 12 + .../org/apache/spark/types/variant/Variant.java | 91 +++++++ .../apache/spark/types/variant/VariantUtil.java | 298 +++++++++++++++++++++ docs/sql-error-conditions.md | 12 + .../sql/catalyst/expressions/jsonExpressions.scala | 8 +- .../spark/sql/catalyst/json/JacksonGenerator.scala | 16 +- .../variant/VariantExpressionSuite.scala | 255 +++++++++++++++++- .../scala/org/apache/spark/sql/VariantSuite.scala | 9 +- .../sql/expressions/ExpressionInfoSuite.scala | 3 - 11 files changed, 696 insertions(+), 18 deletions(-) diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 13b45f55a4ad..a5ef9847859a 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -47,6 +47,12 @@ <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + 
<artifactId>spark-variant_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> <groupId>org.scala-lang.modules</groupId> <artifactId>scala-parallel-collections_${scala.binary.version}</artifactId> diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java index e0f04d816d0d..652c05daf344 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import org.apache.spark.unsafe.Platform; +import org.apache.spark.types.variant.Variant; import java.io.Serializable; import java.util.Arrays; @@ -104,8 +105,7 @@ public class VariantVal implements Serializable { */ @Override public String toString() { - // NOTE: the encoding is not yet implemented, this is not the final implementation. - return new String(value); + return new Variant(value, metadata).toJson(); } /** diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 091f24d44f66..c219db8c6969 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2876,6 +2876,12 @@ }, "sqlState" : "22023" }, + "MALFORMED_VARIANT" : { + "message" : [ + "Variant binary is malformed. Please check the data source is valid." + ], + "sqlState" : "22023" + }, "MERGE_CARDINALITY_VIOLATION" : { "message" : [ "The ON search condition of the MERGE statement matched a single row from the target table with multiple rows of the source table.", @@ -4555,6 +4561,12 @@ ], "sqlState" : "42883" }, + "VARIANT_CONSTRUCTOR_SIZE_LIMIT" : { + "message" : [ + "Cannot construct a Variant larger than 16 MiB. The maximum allowed size of a Variant value is 16 MiB." 
+ ], + "sqlState" : "22023" + }, "VARIANT_SIZE_LIMIT" : { "message" : [ "Cannot build variant bigger than <sizeLimit> in <functionName>.", diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index e43b7ec8ac54..746b38c697d0 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -17,6 +17,14 @@ package org.apache.spark.types.variant; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; + +import java.io.CharArrayWriter; +import java.io.IOException; + +import static org.apache.spark.types.variant.VariantUtil.*; + /** * This class is structurally equivalent to {@link org.apache.spark.unsafe.types.VariantVal}. We * define a new class to avoid depending on or modifying Spark. @@ -28,6 +36,15 @@ public final class Variant { public Variant(byte[] value, byte[] metadata) { this.value = value; this.metadata = metadata; + // There is currently only one allowed version. + if (metadata.length < 1 || (metadata[0] & VERSION_MASK) != VERSION) { + throw malformedVariant(); + } + // Don't attempt to use a Variant larger than 16 MiB. We'll never produce one, and it risks + // memory instability. + if (metadata.length > SIZE_LIMIT || value.length > SIZE_LIMIT) { + throw variantConstructorSizeLimit(); + } } public byte[] getValue() { @@ -37,4 +54,78 @@ public final class Variant { public byte[] getMetadata() { return metadata; } + + // Stringify the variant in JSON format. + // Throw `MALFORMED_VARIANT` if the variant is malformed. + public String toJson() { + StringBuilder sb = new StringBuilder(); + toJsonImpl(value, metadata, 0, sb); + return sb.toString(); + } + + // Escape a string so that it can be pasted into JSON structure. 
+ // For example, if `str` only contains a new-line character, then the result content is "\n" + // (4 characters). + static String escapeJson(String str) { + try (CharArrayWriter writer = new CharArrayWriter(); + JsonGenerator gen = new JsonFactory().createGenerator(writer)) { + gen.writeString(str); + gen.flush(); + return writer.toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb) { + switch (VariantUtil.getType(value, pos)) { + case OBJECT: + handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> { + sb.append('{'); + for (int i = 0; i < size; ++i) { + int id = readUnsigned(value, idStart + idSize * i, idSize); + int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize); + int elementPos = dataStart + offset; + if (i != 0) sb.append(','); + sb.append(escapeJson(getMetadataKey(metadata, id))); + sb.append(':'); + toJsonImpl(value, metadata, elementPos, sb); + } + sb.append('}'); + return null; + }); + break; + case ARRAY: + handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> { + sb.append('['); + for (int i = 0; i < size; ++i) { + int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize); + int elementPos = dataStart + offset; + if (i != 0) sb.append(','); + toJsonImpl(value, metadata, elementPos, sb); + } + sb.append(']'); + return null; + }); + break; + case NULL: + sb.append("null"); + break; + case BOOLEAN: + sb.append(VariantUtil.getBoolean(value, pos)); + break; + case LONG: + sb.append(VariantUtil.getLong(value, pos)); + break; + case STRING: + sb.append(escapeJson(VariantUtil.getString(value, pos))); + break; + case DOUBLE: + sb.append(VariantUtil.getDouble(value, pos)); + break; + case DECIMAL: + sb.append(VariantUtil.getDecimal(value, pos).toPlainString()); + break; + } + } } diff --git 
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index d6e572f98901..b601b7c75eff 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -17,6 +17,13 @@ package org.apache.spark.types.variant; +import org.apache.spark.QueryContext; +import org.apache.spark.SparkRuntimeException; +import scala.collection.immutable.Map$; + +import java.math.BigDecimal; +import java.math.BigInteger; + /** * This class defines constants related to the variant format and provides functions for * manipulating variant binaries. @@ -141,4 +148,295 @@ public class VariantUtil { return (byte) (((largeSize ? 1 : 0) << (BASIC_TYPE_BITS + 2)) | ((offsetSize - 1) << BASIC_TYPE_BITS) | ARRAY); } + + // An exception indicating that the variant value or metadata is malformed. + static SparkRuntimeException malformedVariant() { + return new SparkRuntimeException("MALFORMED_VARIANT", + Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, ""); + } + + // An exception indicating that an external caller tried to call the Variant constructor with + // value or metadata exceeding the 16MiB size limit. We will never construct a Variant this large, + // so it should only be possible to encounter this exception when reading a Variant produced by + // another tool. + static SparkRuntimeException variantConstructorSizeLimit() { + return new SparkRuntimeException("VARIANT_CONSTRUCTOR_SIZE_LIMIT", + Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, ""); + } + + // Check the validity of an array index `pos`. Throw `MALFORMED_VARIANT` if it is out of bound, + // meaning that the variant is malformed. 
+ static void checkIndex(int pos, int length) { + if (pos < 0 || pos >= length) throw malformedVariant(); + } + + // Read a little-endian signed long value from `bytes[pos, pos + numBytes)`. + static long readLong(byte[] bytes, int pos, int numBytes) { + checkIndex(pos, bytes.length); + checkIndex(pos + numBytes - 1, bytes.length); + long result = 0; + // All bytes except the most significant byte should be unsign-extended and shifted (so we need + // `& 0xFF`). The most significant byte should be sign-extended and is handled after the loop. + for (int i = 0; i < numBytes - 1; ++i) { + long unsignedByteValue = bytes[pos + i] & 0xFF; + result |= unsignedByteValue << (8 * i); + } + long signedByteValue = bytes[pos + numBytes - 1]; + result |= signedByteValue << (8 * (numBytes - 1)); + return result; + } + + // Read a little-endian unsigned int value from `bytes[pos, pos + numBytes)`. The value must fit + // into a non-negative int (`[0, Integer.MAX_VALUE]`). + static int readUnsigned(byte[] bytes, int pos, int numBytes) { + checkIndex(pos, bytes.length); + checkIndex(pos + numBytes - 1, bytes.length); + int result = 0; + // Similar to the `readLong` loop, but all bytes should be unsign-extended. + for (int i = 0; i < numBytes; ++i) { + int unsignedByteValue = bytes[pos + i] & 0xFF; + result |= unsignedByteValue << (8 * i); + } + if (result < 0) throw malformedVariant(); + return result; + } + + // The value type of variant value. It is determined by the header byte but not a 1:1 mapping + // (for example, INT1/2/4/8 all maps to `Type.LONG`). + public enum Type { + OBJECT, + ARRAY, + NULL, + BOOLEAN, + LONG, + STRING, + DOUBLE, + DECIMAL, + } + + // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if + // `getType` returns this type (for example, it is only legal to call `getLong` if `getType` + // returns `Type.Long`). + // Throw `MALFORMED_VARIANT` if the variant is malformed. 
+ public static Type getType(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + switch (basicType) { + case SHORT_STR: + return Type.STRING; + case OBJECT: + return Type.OBJECT; + case ARRAY: + return Type.ARRAY; + default: + switch (typeInfo) { + case NULL: + return Type.NULL; + case TRUE: + case FALSE: + return Type.BOOLEAN; + case INT1: + case INT2: + case INT4: + case INT8: + return Type.LONG; + case DOUBLE: + return Type.DOUBLE; + case DECIMAL4: + case DECIMAL8: + case DECIMAL16: + return Type.DECIMAL; + case LONG_STR: + return Type.STRING; + default: + throw malformedVariant(); + } + } + } + + static IllegalStateException unexpectedType(Type type) { + return new IllegalStateException("Expect type to be " + type); + } + + // Get a boolean value from variant value `value[pos...]`. + // Throw `MALFORMED_VARIANT` if the variant is malformed. + public static boolean getBoolean(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != PRIMITIVE || (typeInfo != TRUE && typeInfo != FALSE)) { + throw unexpectedType(Type.BOOLEAN); + } + return typeInfo == TRUE; + } + + // Get a long value from variant value `value[pos...]`. + // Throw `MALFORMED_VARIANT` if the variant is malformed. 
+ public static long getLong(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG); + switch (typeInfo) { + case INT1: + return readLong(value, pos + 1, 1); + case INT2: + return readLong(value, pos + 1, 2); + case INT4: + return readLong(value, pos + 1, 4); + case INT8: + return readLong(value, pos + 1, 8); + default: + throw unexpectedType(Type.LONG); + } + } + + // Get a double value from variant value `value[pos...]`. + // Throw `MALFORMED_VARIANT` if the variant is malformed. + public static double getDouble(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != PRIMITIVE || typeInfo != DOUBLE) throw unexpectedType(Type.DOUBLE); + return Double.longBitsToDouble(readLong(value, pos + 1, 8)); + } + + // Get a decimal value from variant value `value[pos...]`. + // Throw `MALFORMED_VARIANT` if the variant is malformed. + public static BigDecimal getDecimal(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL); + int scale = value[pos + 1]; + BigDecimal result; + switch (typeInfo) { + case DECIMAL4: + result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale); + break; + case DECIMAL8: + result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale); + break; + case DECIMAL16: + checkIndex(pos + 17, value.length); + byte[] bytes = new byte[16]; + // Copy the bytes reversely because the `BigInteger` constructor expects a big-endian + // representation. 
+ for (int i = 0; i < 16; ++i) { + bytes[i] = value[pos + 17 - i]; + } + result = new BigDecimal(new BigInteger(bytes), scale); + break; + default: + throw unexpectedType(Type.DECIMAL); + } + return result.stripTrailingZeros(); + } + + // Get a string value from variant value `value[pos...]`. + // Throw `MALFORMED_VARIANT` if the variant is malformed. + public static String getString(byte[] value, int pos) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType == SHORT_STR || (basicType == PRIMITIVE && typeInfo == LONG_STR)) { + int start; + int length; + if (basicType == SHORT_STR) { + start = pos + 1; + length = typeInfo; + } else { + start = pos + 1 + U32_SIZE; + length = readUnsigned(value, pos + 1, U32_SIZE); + } + checkIndex(start + length - 1, value.length); + return new String(value, start, length); + } + throw unexpectedType(Type.STRING); + } + + public interface ObjectHandler<T> { + /** + * @param size Number of object fields. + * @param idSize The integer size of the field id list. + * @param offsetSize The integer size of the offset list. + * @param idStart The starting index of the field id list in the variant value array. + * @param offsetStart The starting index of the offset list in the variant value array. + * @param dataStart The starting index of field data in the variant value array. + */ + T apply(int size, int idSize, int offsetSize, int idStart, int offsetStart, int dataStart); + } + + // A helper function to access a variant object. It provides `handler` with its required + // parameters and returns what it returns. 
+ public static <T> T handleObject(byte[] value, int pos, ObjectHandler<T> handler) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != OBJECT) throw unexpectedType(Type.OBJECT); + // Refer to the comment of the `OBJECT` constant for the details of the object header encoding. + // Suppose `typeInfo` has a bit representation of 0_b4_b3b2_b1b0, the following line extracts + // b4 to determine whether the object uses a 1/4-byte size. + boolean largeSize = ((typeInfo >> 4) & 0x1) != 0; + int sizeBytes = (largeSize ? U32_SIZE : 1); + int size = readUnsigned(value, pos + 1, sizeBytes); + // Extracts b3b2 to determine the integer size of the field id list. + int idSize = ((typeInfo >> 2) & 0x3) + 1; + // Extracts b1b0 to determine the integer size of the offset list. + int offsetSize = (typeInfo & 0x3) + 1; + int idStart = pos + 1 + sizeBytes; + int offsetStart = idStart + size * idSize; + int dataStart = offsetStart + (size + 1) * offsetSize; + return handler.apply(size, idSize, offsetSize, idStart, offsetStart, dataStart); + } + + public interface ArrayHandler<T> { + /** + * @param size Number of array elements. + * @param offsetSize The integer size of the offset list. + * @param offsetStart The starting index of the offset list in the variant value array. + * @param dataStart The starting index of element data in the variant value array. + */ + T apply(int size, int offsetSize, int offsetStart, int dataStart); + } + + // A helper function to access a variant array. + public static <T> T handleArray(byte[] value, int pos, ArrayHandler<T> handler) { + checkIndex(pos, value.length); + int basicType = value[pos] & BASIC_TYPE_MASK; + int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; + if (basicType != ARRAY) throw unexpectedType(Type.ARRAY); + // Refer to the comment of the `ARRAY` constant for the details of the object header encoding. 
+ // Suppose `typeInfo` has a bit representation of 000_b2_b1b0, the following line extracts + // b2 to determine whether the object uses a 1/4-byte size. + boolean largeSize = ((typeInfo >> 2) & 0x1) != 0; + int sizeBytes = (largeSize ? U32_SIZE : 1); + int size = readUnsigned(value, pos + 1, sizeBytes); + // Extracts b1b0 to determine the integer size of the offset list. + int offsetSize = (typeInfo & 0x3) + 1; + int offsetStart = pos + 1 + sizeBytes; + int dataStart = offsetStart + (size + 1) * offsetSize; + return handler.apply(size, offsetSize, offsetStart, dataStart); + } + + // Get a key at `id` in the variant metadata. + // Throw `MALFORMED_VARIANT` if the variant is malformed. An out-of-bound `id` is also considered + // a malformed variant because it is read from the corresponding variant value. + public static String getMetadataKey(byte[] metadata, int id) { + checkIndex(0, metadata.length); + // Extracts the highest 2 bits in the metadata header to determine the integer size of the + // offset list. + int offsetSize = ((metadata[0] >> 6) & 0x3) + 1; + int dictSize = readUnsigned(metadata, 1, offsetSize); + if (id >= dictSize) throw malformedVariant(); + // There are a header byte, a `dictSize` with `offsetSize` bytes, and `(dictSize + 1)` offsets + // before the string data. + int stringStart = 1 + (dictSize + 2) * offsetSize; + int offset = readUnsigned(metadata, 1 + (id + 1) * offsetSize, offsetSize); + int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize); + if (offset > nextOffset) throw malformedVariant(); + checkIndex(stringStart + nextOffset - 1, metadata.length); + return new String(metadata, stringStart + offset, nextOffset - offset); + } } diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index bab64caa3888..8b666c1ef9c8 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1589,6 +1589,12 @@ Parse Mode: `<failFastMode>`. 
To process malformed records as null result, try s For more details see [MALFORMED_RECORD_IN_PARSING](sql-error-conditions-malformed-record-in-parsing-error-class.html) +### MALFORMED_VARIANT + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Variant binary is malformed. Please check the data source is valid. + ### MERGE_CARDINALITY_VIOLATION [SQLSTATE: 23K01](sql-error-conditions-sqlstates.html#class-23-integrity-constraint-violation) @@ -2732,6 +2738,12 @@ The variable `<variableName>` cannot be found. Verify the spelling and correctne If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. To tolerate the error on drop use DROP VARIABLE IF EXISTS. +### VARIANT_CONSTRUCTOR_SIZE_LIMIT + +[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) + +Cannot construct a Variant larger than 16 MiB. The maximum allowed size of a Variant value is 16 MiB. 
+ ### VARIANT_SIZE_LIMIT [SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index b155987242b3..f35c6da4f8af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils private[this] sealed trait PathInstruction @@ -813,13 +813,17 @@ case class StructsToJson( (map: Any) => gen.write(map.asInstanceOf[MapData]) getAndReset() + case _: VariantType => + (v: Any) => + gen.write(v.asInstanceOf[VariantVal]) + getAndReset() } } override def dataType: DataType = StringType override def checkInputDataTypes(): TypeCheckResult = inputSchema match { - case dt @ (_: StructType | _: MapType | _: ArrayType) => + case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) => JacksonUtils.verifyType(prettyName, dt) case _ => DataTypeMismatch( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index c2c6117e1e3a..1964b5f24b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util._ import 
org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.VariantVal import org.apache.spark.util.ArrayImplicits._ /** @@ -46,11 +47,13 @@ class JacksonGenerator( // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. private type ValueWriter = (SpecializedGetters, Int) => Unit - // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`. + // `JacksonGenerator` can only be initialized with a `StructType`, a `MapType`, an `ArrayType` or a + // `VariantType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType] - || dataType.isInstanceOf[ArrayType], + || dataType.isInstanceOf[ArrayType] || dataType.isInstanceOf[VariantType], s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " + - s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}") + s"${MapType.simpleString}, ${ArrayType.simpleString} or ${VariantType.simpleString} but " + + s"got ${dataType.catalogString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { @@ -202,6 +205,9 @@ class JacksonGenerator( (row: SpecializedGetters, ordinal: Int) => writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter)) + case VariantType => + (row: SpecializedGetters, ordinal: Int) => write(row.getVariant(ordinal)) + // For UDT values, they should be in the SQL type's corresponding value type. // We should not see values in the user-defined class at here. // For example, VectorUDT's SQL type is an array of double. 
So, we should expect that v is @@ -310,6 +316,10 @@ class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } + def write(v: VariantVal): Unit = { + gen.writeRawValue(v.toString) + } + def writeLineEnding(): Unit = { // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. gen.writeRaw(lineSeparator) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 22155c927e37..2793b1c8c1fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -18,11 +18,22 @@ package org.apache.spark.sql.catalyst.expressions.variant import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException} -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.types.variant.VariantUtil._ import org.apache.spark.unsafe.types.VariantVal class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + // Zero-extend each byte in the array with the appropriate number of bytes. + // Used to manually construct variant binary values with a given offset size. + // E.g. padded(Array(1,2,3), 3) will produce Array(1,0,0,2,0,0,3,0,0). 
+ private def padded(a: Array[Byte], size: Int): Array[Byte] = { + a.flatMap { b => + val padding = List.fill(size - 1)(0.toByte) + b :: padding + } + } + test("parse_json") { def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = { checkEvaluation(ParseJson(Literal(json)), new VariantVal(expectedValue, expectedMetadata)) @@ -111,4 +122,246 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { "Cannot build variant bigger than 16.0 MiB") } } + + test("round-trip") { + def check(input: String, output: String = null): Unit = { + checkEvaluation( + StructsToJson(Map.empty, ParseJson(Literal(input))), + if (output != null) output else input + ) + } + + check("null") + check("true") + check("false") + check("-1") + check("1.0E10") + check("\"\"") + check("\"" + ("a" * 63) + "\"") + check("\"" + ("b" * 64) + "\"") + // scalastyle:off nonascii + check("\"" + ("你好,世界" * 20) + "\"") + // scalastyle:on nonascii + check("[]") + check("{}") + // scalastyle:off nonascii + check( + "[null, true, false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]", + "[null,true,false,-1,1.0E10,\"😅\",[],{}]" + ) + // scalastyle:on nonascii + check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]") + } + + test("to_json with nested variant") { + checkEvaluation( + StructsToJson(Map.empty, CreateArray(Seq(ParseJson(Literal("{}")), + ParseJson(Literal("\"\"")), + ParseJson(Literal("[1, 2, 3]"))))), + "[{},\"\",[1,2,3]]" + ) + checkEvaluation( + StructsToJson(Map.empty, CreateNamedStruct(Seq( + Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str" }""")), + Literal("b"), ParseJson(Literal("[[]]")), + Literal("c"), ParseJson(Literal("false"))))), + """{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}""" + ) + } + + test("to_json malformed") { + def check(value: Array[Byte], metadata: Array[Byte], + errorClass: String = "MALFORMED_VARIANT"): Unit = { + checkErrorInExpression[SparkRuntimeException]( + ResolveTimeZone.resolveTimeZones( 
+ StructsToJson(Map.empty, Literal(new VariantVal(value, metadata)))), + errorClass + ) + } + + val emptyMetadata = Array[Byte](VERSION, 0, 0) + // INT8 only has 7 byte content. + check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata) + // DECIMAL16 only has 15 byte content. + check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata) + // Short string content too short. + check(Array(shortStrHeader(2), 'x'), emptyMetadata) + // Long string length too short (requires 4 bytes). + check(Array(primitiveHeader(LONG_STR), 0, 0, 0), emptyMetadata) + // Long string content too short. + check(Array(primitiveHeader(LONG_STR), 1, 0, 0, 0), emptyMetadata) + // Size is 1 but no content. + check(Array(arrayHeader(false, 1), + /* size */ 1, + /* offset list */ 0), emptyMetadata) + // Requires a 4-byte size, but the actual size only has one byte. + check(Array(arrayHeader(true, 1), + /* size */ 0, + /* offset list */ 0), emptyMetadata) + // Offset out of bound. + check(Array(arrayHeader(false, 1), + /* size */ 1, + /* offset list */ 1, 1), emptyMetadata) + // Id out of bound. + check(Array(objectHeader(false, 1, 1), + /* size */ 1, + /* id list */ 0, + /* offset list */ 0, 2, + /* field data */ primitiveHeader(INT1), 1), emptyMetadata) + // Variant version is not 1. + check(Array(primitiveHeader(INT1), 0), Array[Byte](3, 0, 0)) + check(Array(primitiveHeader(INT1), 0), Array[Byte](2, 0, 0)) + + // Construct binary values that are over 1 << 24 bytes, but otherwise valid. + val bigVersion = Array[Byte]((VERSION | (3 << 6)).toByte) + val a = Array.fill(1 << 24)('a'.toByte) + val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1) ++ + a ++ Array[Byte]('b') + check(Array(primitiveHeader(TRUE)), hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT") + + // The keys are 'aaa....' and 'b'. Values are "yyy..." and 'true'. 
+ val y = Array.fill(1 << 24)('y'.toByte) + val hugeObject = Array[Byte](objectHeader(true, 4, 4)) ++ + /* size */ padded(Array(2), 4) ++ + /* id list */ padded(Array(0, 1), 4) ++ + // Second value starts at offset 5 + (1 << 24), which is `5001` little-endian. The last value + // is 1 byte, so the one-past-the-end value is `6001` + /* offset list */ Array[Byte](0, 0, 0, 0, 5, 0, 0, 1, 6, 0, 0, 1) ++ + /* field data */ Array[Byte](primitiveHeader(LONG_STR), 0, 0, 0, 1) ++ y ++ Array[Byte]( + primitiveHeader(TRUE) + ) + + val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0) ++ + Array[Byte]('a', 'b') + check(hugeObject, smallMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT") + check(hugeObject, hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT") + } + + // Test valid forms of Variant that our writer would never produce. + test("to_json valid input") { + def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = { + checkEvaluation( + StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))), + expectedJson + ) + } + // Some valid metadata formats. Check that they aren't rejected. + // Sorted string bit is set, and can be ignored. + val emptyMetadata2 = Array[Byte](VERSION | 1 << 4, 0, 0) + // Bit 5 is not defined in the spec, and can be ignored. + val emptyMetadata3 = Array[Byte](VERSION | 1 << 5, 0, 0) + // Can specify 3 bytes per size/offset, even if they aren't needed. + val header = (VERSION | (2 << 6)).toByte + val emptyMetadata4 = Array[Byte](header, 0, 0, 0, 0, 0, 0) + check("true", Array(primitiveHeader(TRUE)), emptyMetadata2) + check("true", Array(primitiveHeader(TRUE)), emptyMetadata3) + check("true", Array(primitiveHeader(TRUE)), emptyMetadata4) + } + + // Test StructsToJson with manually constructed input that uses up to 4 bytes for offsets and + // sizes. 
We never produce 4-byte offsets, since they're only needed for >16 MiB values, which we + // error out on, but the reader should be able to handle them if some other writer decides to use + // them for smaller values. + test("to_json with large offsets and sizes") { + def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = { + checkEvaluation( + StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))), + expectedJson + ) + } + + for { + offsetSize <- 1 to 4 + idSize <- 1 to 4 + metadataSize <- 1 to 4 + largeSize <- Seq(false, true) + } { + // Test array + val version = Array[Byte]((VERSION | ((metadataSize - 1) << 6)).toByte) + val emptyMetadata = version ++ padded(Array(0, 0), metadataSize) + // Construct a binary with the given sizes. Regardless, to_json should produce the same + // result. + val arrayValue = Array[Byte](arrayHeader(largeSize, offsetSize)) ++ + /* size */ padded(Array(3), if (largeSize) 4 else 1) ++ + /* offset list */ padded(Array(0, 1, 4, 5), offsetSize) ++ + Array[Byte](/* values */ primitiveHeader(FALSE), + primitiveHeader(INT2), 2, 1, primitiveHeader(NULL)) + check("[false,258,null]", arrayValue, emptyMetadata) + + // Test object + val metadata = version ++ + padded(Array(3, 0, 1, 2, 3), metadataSize) ++ + Array[Byte]('a', 'b', 'c') + val objectValue = Array[Byte](objectHeader(largeSize, idSize, offsetSize)) ++ + /* size */ padded(Array(3), if (largeSize) 4 else 1) ++ + /* id list */ padded(Array(0, 1, 2), idSize) ++ + /* offset list */ padded(Array(0, 2, 4, 6), offsetSize) ++ + /* field data */ Array[Byte](primitiveHeader(INT1), 1, + primitiveHeader(INT1), 2, shortStrHeader(1), '3') + + check("""{"a":1,"b":2,"c":"3"}""", objectValue, metadata) + } + } + + test("to_json large binary") { + def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = { + checkEvaluation( + StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))), + expectedJson + ) + } + + // Create a binary 
that uses the max 1 << 24 bytes for both metadata and value. + val bigVersion = Array[Byte]((VERSION | (2 << 6)).toByte) + // Create a single huge value, followed by a one-byte string. We'll have 1 header byte, plus 12 + // bytes for size and offsets, plus 1 byte for the final value, so the large value is + // (1 << 24) - 14 bytes, or (-14, -1, -1) as a signed little-endian value. + val aSize = (1 << 24) - 14 + val a = Array.fill(aSize)('a'.toByte) + val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, -14, -1, -1, -13, -1, -1) ++ + a ++ Array[Byte]('b') + // Validate metadata in isolation. + check("true", Array(primitiveHeader(TRUE)), hugeMetadata) + + // The object will contain a large string, and the following bytes: + // - object header and size: 1+4 bytes + // - ID list: 6 bytes + // - offset list: 9 bytes + // - field headers and string length: 6 bytes + // In order to get the full binary to 1 << 24, the large string is (1 << 24) - 26 bytes. As a + // signed little-endian value, this is (-26, -1, -1). + val ySize = (1 << 24) - 26 + val y = Array.fill(ySize)('y'.toByte) + val hugeObject = Array[Byte](objectHeader(true, 3, 3)) ++ + /* size */ padded(Array(2), 4) ++ + /* id list */ padded(Array(0, 1), 3) ++ + // Second offset is (-26,-1,-1), plus 5 bytes for string header, so (-21,-1,-1) + /* offset list */ Array[Byte](0, 0, 0, -21, -1, -1, -20, -1, -1) ++ + /* field data */ Array[Byte](primitiveHeader(LONG_STR), -26, -1, -1, 0) ++ y ++ Array[Byte]( + primitiveHeader(TRUE) + ) + // Same as hugeObject, but with a short string. 
+ val smallObject = Array[Byte](objectHeader(false, 1, 1)) ++ + /* size */ Array[Byte](2) ++ + /* id list */ Array[Byte](0, 1) ++ + /* offset list */ Array[Byte](0, 6, 7) ++ + /* field data */ Array[Byte](primitiveHeader(LONG_STR), 1, 0, 0, 0, 'y', + primitiveHeader(TRUE)) + val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0) ++ + Array[Byte]('a', 'b') + + // Check all combinations of large/small value and metadata. + val expectedResult1 = + s"""{"${a.map(_.toChar).mkString}":"${y.map(_.toChar).mkString}","b":true}""" + check(expectedResult1, hugeObject, hugeMetadata) + val expectedResult2 = + s"""{"${a.map(_.toChar).mkString}":"y","b":true}""" + check(expectedResult2, smallObject, hugeMetadata) + val expectedResult3 = + s"""{"a":"${y.map(_.toChar).mkString}","b":true}""" + check(expectedResult3, hugeObject, smallMetadata) + val expectedResult4 = + s"""{"a":"y","b":true}""" + check(expectedResult4, smallObject, smallMetadata) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 1a2c424938a1..3991b44d0bbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -30,10 +30,7 @@ import org.apache.spark.unsafe.types.VariantVal import org.apache.spark.util.ArrayImplicits._ class VariantSuite extends QueryTest with SharedSparkSession { - // TODO(SPARK-45891): We need to ignore some tests for now because the `toString` implementation - // doesn't match the `parse_json` implementation yet. We will shortly add a new `toString` - // implementation and re-enable the tests. 
- ignore("basic tests") { + test("basic tests") { def verifyResult(df: DataFrame): Unit = { val result = df.collect() .map(_.get(0).asInstanceOf[VariantVal].toString) @@ -43,8 +40,6 @@ class VariantSuite extends QueryTest with SharedSparkSession { assert(result == expected) } - // At this point, JSON parsing logic is not really implemented. We just construct some number - // inputs that are also valid JSON. This exercises passing VariantVal throughout the system. val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)") verifyResult(query) @@ -142,7 +137,7 @@ class VariantSuite extends QueryTest with SharedSparkSession { } } - ignore("write partitioned file") { + test("write partitioned file") { def verifyResult(df: DataFrame): Unit = { val result = df.selectExpr("v").collect() .map(_.get(0).asInstanceOf[VariantVal].toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 4c77f26949b0..19251330cffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -196,9 +196,6 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } val exampleRe = """^(.+);\n(?s)(.+)$""".r val ignoreSet = Set( - // TODO(SPARK-45891): need to temporarily ignore it because the `toString` implementation - // doesn't match the `parse_json` implementation yet. - "org.apache.spark.sql.catalyst.expressions.variant.ParseJson", // One of examples shows getting the current timestamp "org.apache.spark.sql.catalyst.expressions.UnixTimestamp", "org.apache.spark.sql.catalyst.expressions.CurrentDate", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org