This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a1f2419e243 [SPARK-47451][SQL] Support to_json(variant).
4a1f2419e243 is described below

commit 4a1f2419e243272064fde96529149316fe53bc10
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Mon Mar 25 14:58:01 2024 +0800

    [SPARK-47451][SQL] Support to_json(variant).
    
    ### What changes were proposed in this pull request?
    
    This PR adds the functionality to format a variant value as a JSON string. It is exposed in
    the `to_json` expression by allowing the variant type (or a nested type containing the
    variant type) as its input.
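
    For illustration, a minimal usage sketch (hedged: assumes an active `SparkSession` named
    `spark`; the JSON literal is arbitrary):

        // Format a parsed variant back into a compact JSON string.
        val out = spark.sql(
          """SELECT to_json(parse_json('{"a": 1, "b": [true, "str"]}')) AS j""")
        out.show(truncate = false) // expected cell: {"a":1,"b":[true,"str"]}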
    
    ### How was this patch tested?
    
    Unit tests that validate the `to_json` result. The input includes both `parse_json` results
    and manually constructed bytes. Negative cases with malformed inputs are also covered.
    
    Some tests disabled in https://github.com/apache/spark/pull/45479 are re-enabled.
    
    Closes #45575 from chenhao-db/variant_to_json.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 common/unsafe/pom.xml                              |   6 +
 .../org/apache/spark/unsafe/types/VariantVal.java  |   4 +-
 .../src/main/resources/error/error-classes.json    |  12 +
 .../org/apache/spark/types/variant/Variant.java    |  91 +++++++
 .../apache/spark/types/variant/VariantUtil.java    | 298 +++++++++++++++++++++
 docs/sql-error-conditions.md                       |  12 +
 .../sql/catalyst/expressions/jsonExpressions.scala |   8 +-
 .../spark/sql/catalyst/json/JacksonGenerator.scala |  16 +-
 .../variant/VariantExpressionSuite.scala           | 255 +++++++++++++++++-
 .../scala/org/apache/spark/sql/VariantSuite.scala  |   9 +-
 .../sql/expressions/ExpressionInfoSuite.scala      |   3 -
 11 files changed, 696 insertions(+), 18 deletions(-)

diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index 13b45f55a4ad..a5ef9847859a 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -47,6 +47,12 @@
       <version>${project.version}</version>
     </dependency>
 
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-variant_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+
     <dependency>
       <groupId>org.scala-lang.modules</groupId>
       <artifactId>scala-parallel-collections_${scala.binary.version}</artifactId>
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
index e0f04d816d0d..652c05daf344 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -18,6 +18,7 @@
 package org.apache.spark.unsafe.types;
 
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.types.variant.Variant;
 
 import java.io.Serializable;
 import java.util.Arrays;
@@ -104,8 +105,7 @@ public class VariantVal implements Serializable {
    */
   @Override
   public String toString() {
-    // NOTE: the encoding is not yet implemented, this is not the final implementation.
-    return new String(value);
+    return new Variant(value, metadata).toJson();
   }
 
   /**
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 091f24d44f66..c219db8c6969 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2876,6 +2876,12 @@
     },
     "sqlState" : "22023"
   },
+  "MALFORMED_VARIANT" : {
+    "message" : [
+      "Variant binary is malformed. Please check the data source is valid."
+    ],
+    "sqlState" : "22023"
+  },
   "MERGE_CARDINALITY_VIOLATION" : {
     "message" : [
       "The ON search condition of the MERGE statement matched a single row 
from the target table with multiple rows of the source table.",
@@ -4555,6 +4561,12 @@
     ],
     "sqlState" : "42883"
   },
+  "VARIANT_CONSTRUCTOR_SIZE_LIMIT" : {
+    "message" : [
+      "Cannot construct a Variant larger than 16 MiB. The maximum allowed size 
of a Variant value is 16 MiB."
+    ],
+    "sqlState" : "22023"
+  },
   "VARIANT_SIZE_LIMIT" : {
     "message" : [
       "Cannot build variant bigger than <sizeLimit> in <functionName>.",
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index e43b7ec8ac54..746b38c697d0 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -17,6 +17,14 @@
 
 package org.apache.spark.types.variant;
 
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
+
+import java.io.CharArrayWriter;
+import java.io.IOException;
+
+import static org.apache.spark.types.variant.VariantUtil.*;
+
 /**
 * This class is structurally equivalent to {@link org.apache.spark.unsafe.types.VariantVal}. We
  * define a new class to avoid depending on or modifying Spark.
@@ -28,6 +36,15 @@ public final class Variant {
   public Variant(byte[] value, byte[] metadata) {
     this.value = value;
     this.metadata = metadata;
+    // There is currently only one allowed version.
+    if (metadata.length < 1 || (metadata[0] & VERSION_MASK) != VERSION) {
+      throw malformedVariant();
+    }
+    // Don't attempt to use a Variant larger than 16 MiB. We'll never produce one, and it
+    // risks memory instability.
+    if (metadata.length > SIZE_LIMIT || value.length > SIZE_LIMIT) {
+      throw variantConstructorSizeLimit();
+    }
   }
 
   public byte[] getValue() {
@@ -37,4 +54,78 @@ public final class Variant {
   public byte[] getMetadata() {
     return metadata;
   }
+
+  // Stringify the variant in JSON format.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public String toJson() {
+    StringBuilder sb = new StringBuilder();
+    toJsonImpl(value, metadata, 0, sb);
+    return sb.toString();
+  }
+
+  // Escape a string so that it can be pasted into JSON structure.
+  // For example, if `str` only contains a new-line character, then the result content is "\n"
+  // (4 characters).
+  static String escapeJson(String str) {
+    try (CharArrayWriter writer = new CharArrayWriter();
+         JsonGenerator gen = new JsonFactory().createGenerator(writer)) {
+      gen.writeString(str);
+      gen.flush();
+      return writer.toString();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb) {
+    switch (VariantUtil.getType(value, pos)) {
+      case OBJECT:
+        handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
+          sb.append('{');
+          for (int i = 0; i < size; ++i) {
+            int id = readUnsigned(value, idStart + idSize * i, idSize);
+            int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+            int elementPos = dataStart + offset;
+            if (i != 0) sb.append(',');
+            sb.append(escapeJson(getMetadataKey(metadata, id)));
+            sb.append(':');
+            toJsonImpl(value, metadata, elementPos, sb);
+          }
+          sb.append('}');
+          return null;
+        });
+        break;
+      case ARRAY:
+        handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> {
+          sb.append('[');
+          for (int i = 0; i < size; ++i) {
+            int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+            int elementPos = dataStart + offset;
+            if (i != 0) sb.append(',');
+            toJsonImpl(value, metadata, elementPos, sb);
+          }
+          sb.append(']');
+          return null;
+        });
+        break;
+      case NULL:
+        sb.append("null");
+        break;
+      case BOOLEAN:
+        sb.append(VariantUtil.getBoolean(value, pos));
+        break;
+      case LONG:
+        sb.append(VariantUtil.getLong(value, pos));
+        break;
+      case STRING:
+        sb.append(escapeJson(VariantUtil.getString(value, pos)));
+        break;
+      case DOUBLE:
+        sb.append(VariantUtil.getDouble(value, pos));
+        break;
+      case DECIMAL:
+        sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
+        break;
+    }
+  }
 }
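
For reference, a minimal sketch of driving the new `toJson` entry point directly (hedged: it
reuses the `VERSION` constant and `primitiveHeader` helper from `VariantUtil`, mirroring the
`emptyMetadata` byte layout used in the test suite below):

    import org.apache.spark.types.variant.Variant
    import org.apache.spark.types.variant.VariantUtil._

    // The scalar `true`: one primitive header byte as the value, and metadata
    // holding the version byte plus an empty dictionary (size 0, one offset 0).
    val value = Array[Byte](primitiveHeader(TRUE))
    val metadata = Array[Byte](VERSION, 0, 0)
    println(new Variant(value, metadata).toJson()) // prints: true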
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index d6e572f98901..b601b7c75eff 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -17,6 +17,13 @@
 
 package org.apache.spark.types.variant;
 
+import org.apache.spark.QueryContext;
+import org.apache.spark.SparkRuntimeException;
+import scala.collection.immutable.Map$;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
 /**
 * This class defines constants related to the variant format and provides functions for
  * manipulating variant binaries.
@@ -141,4 +148,295 @@ public class VariantUtil {
     return (byte) (((largeSize ? 1 : 0) << (BASIC_TYPE_BITS + 2)) |
         ((offsetSize - 1) << BASIC_TYPE_BITS) | ARRAY);
   }
+
+  // An exception indicating that the variant value or metadata doesn't conform to the
+  // variant binary format.
+  static SparkRuntimeException malformedVariant() {
+    return new SparkRuntimeException("MALFORMED_VARIANT",
+        Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, "");
+  }
+
+  // An exception indicating that an external caller tried to call the Variant constructor with
+  // value or metadata exceeding the 16MiB size limit. We will never construct a Variant this
+  // large, so it should only be possible to encounter this exception when reading a Variant
+  // produced by another tool.
+  static SparkRuntimeException variantConstructorSizeLimit() {
+    return new SparkRuntimeException("VARIANT_CONSTRUCTOR_SIZE_LIMIT",
+        Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, "");
+  }
+
+  // Check the validity of an array index `pos`. Throw `MALFORMED_VARIANT` if it is out of
+  // bound, meaning that the variant is malformed.
+  static void checkIndex(int pos, int length) {
+    if (pos < 0 || pos >= length) throw malformedVariant();
+  }
+
+  // Read a little-endian signed long value from `bytes[pos, pos + numBytes)`.
+  static long readLong(byte[] bytes, int pos, int numBytes) {
+    checkIndex(pos, bytes.length);
+    checkIndex(pos + numBytes - 1, bytes.length);
+    long result = 0;
+    // All bytes except the most significant byte should be unsign-extended and shifted (so we
+    // need `& 0xFF`). The most significant byte should be sign-extended and is handled after
+    // the loop.
+    for (int i = 0; i < numBytes - 1; ++i) {
+      long unsignedByteValue = bytes[pos + i] & 0xFF;
+      result |= unsignedByteValue << (8 * i);
+    }
+    long signedByteValue = bytes[pos + numBytes - 1];
+    result |= signedByteValue << (8 * (numBytes - 1));
+    return result;
+  }
+
+  // Read a little-endian unsigned int value from `bytes[pos, pos + numBytes)`. The value must
+  // fit into a non-negative int (`[0, Integer.MAX_VALUE]`).
+  static int readUnsigned(byte[] bytes, int pos, int numBytes) {
+    checkIndex(pos, bytes.length);
+    checkIndex(pos + numBytes - 1, bytes.length);
+    int result = 0;
+    // Similar to the `readLong` loop, but all bytes should be unsign-extended.
+    for (int i = 0; i < numBytes; ++i) {
+      int unsignedByteValue = bytes[pos + i] & 0xFF;
+      result |= unsignedByteValue << (8 * i);
+    }
+    if (result < 0) throw malformedVariant();
+    return result;
+  }
+
+  // The value type of variant value. It is determined by the header byte but is not a 1:1
+  // mapping (for example, INT1/2/4/8 all map to `Type.LONG`).
+  public enum Type {
+    OBJECT,
+    ARRAY,
+    NULL,
+    BOOLEAN,
+    LONG,
+    STRING,
+    DOUBLE,
+    DECIMAL,
+  }
+
+  // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if
+  // `getType` returns this type (for example, it is only legal to call `getLong` if `getType`
+  // returns `Type.LONG`).
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static Type getType(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    switch (basicType) {
+      case SHORT_STR:
+        return Type.STRING;
+      case OBJECT:
+        return Type.OBJECT;
+      case ARRAY:
+        return Type.ARRAY;
+      default:
+        switch (typeInfo) {
+          case NULL:
+            return Type.NULL;
+          case TRUE:
+          case FALSE:
+            return Type.BOOLEAN;
+          case INT1:
+          case INT2:
+          case INT4:
+          case INT8:
+            return Type.LONG;
+          case DOUBLE:
+            return Type.DOUBLE;
+          case DECIMAL4:
+          case DECIMAL8:
+          case DECIMAL16:
+            return Type.DECIMAL;
+          case LONG_STR:
+            return Type.STRING;
+          default:
+            throw malformedVariant();
+        }
+    }
+  }
+
+  static IllegalStateException unexpectedType(Type type) {
+    return new IllegalStateException("Expect type to be " + type);
+  }
+
+  // Get a boolean value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static boolean getBoolean(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || (typeInfo != TRUE && typeInfo != FALSE)) {
+      throw unexpectedType(Type.BOOLEAN);
+    }
+    return typeInfo == TRUE;
+  }
+
+  // Get a long value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static long getLong(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG);
+    switch (typeInfo) {
+      case INT1:
+        return readLong(value, pos + 1, 1);
+      case INT2:
+        return readLong(value, pos + 1, 2);
+      case INT4:
+        return readLong(value, pos + 1, 4);
+      case INT8:
+        return readLong(value, pos + 1, 8);
+      default:
+        throw unexpectedType(Type.LONG);
+    }
+  }
+
+  // Get a double value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static double getDouble(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != DOUBLE) throw unexpectedType(Type.DOUBLE);
+    return Double.longBitsToDouble(readLong(value, pos + 1, 8));
+  }
+
+  // Get a decimal value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static BigDecimal getDecimal(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
+    int scale = value[pos + 1];
+    BigDecimal result;
+    switch (typeInfo) {
+      case DECIMAL4:
+        result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
+        break;
+      case DECIMAL8:
+        result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
+        break;
+      case DECIMAL16:
+        checkIndex(pos + 17, value.length);
+        byte[] bytes = new byte[16];
+        // Copy the bytes in reverse order because the `BigInteger` constructor expects a
+        // big-endian representation.
+        for (int i = 0; i < 16; ++i) {
+          bytes[i] = value[pos + 17 - i];
+        }
+        result = new BigDecimal(new BigInteger(bytes), scale);
+        break;
+      default:
+        throw unexpectedType(Type.DECIMAL);
+    }
+    return result.stripTrailingZeros();
+  }
+
+  // Get a string value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static String getString(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType == SHORT_STR || (basicType == PRIMITIVE && typeInfo == LONG_STR)) {
+      int start;
+      int length;
+      if (basicType == SHORT_STR) {
+        start = pos + 1;
+        length = typeInfo;
+      } else {
+        start = pos + 1 + U32_SIZE;
+        length = readUnsigned(value, pos + 1, U32_SIZE);
+      }
+      checkIndex(start + length - 1, value.length);
+      return new String(value, start, length);
+    }
+    throw unexpectedType(Type.STRING);
+  }
+
+  public interface ObjectHandler<T> {
+    /**
+     * @param size Number of object fields.
+     * @param idSize The integer size of the field id list.
+     * @param offsetSize The integer size of the offset list.
+     * @param idStart The starting index of the field id list in the variant value array.
+     * @param offsetStart The starting index of the offset list in the variant value array.
+     * @param dataStart The starting index of field data in the variant value array.
+     */
+    T apply(int size, int idSize, int offsetSize, int idStart, int offsetStart, int dataStart);
+  }
+
+  // A helper function to access a variant object. It provides `handler` with its required
+  // parameters and returns what it returns.
+  public static <T> T handleObject(byte[] value, int pos, ObjectHandler<T> handler) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != OBJECT) throw unexpectedType(Type.OBJECT);
+    // Refer to the comment of the `OBJECT` constant for the details of the object header
+    // encoding. Suppose `typeInfo` has a bit representation of 0_b4_b3b2_b1b0, the following
+    // line extracts b4 to determine whether the object uses a 1/4-byte size.
+    boolean largeSize = ((typeInfo >> 4) & 0x1) != 0;
+    int sizeBytes = (largeSize ? U32_SIZE : 1);
+    int size = readUnsigned(value, pos + 1, sizeBytes);
+    // Extracts b3b2 to determine the integer size of the field id list.
+    int idSize = ((typeInfo >> 2) & 0x3) + 1;
+    // Extracts b1b0 to determine the integer size of the offset list.
+    int offsetSize = (typeInfo & 0x3) + 1;
+    int idStart = pos + 1 + sizeBytes;
+    int offsetStart = idStart + size * idSize;
+    int dataStart = offsetStart + (size + 1) * offsetSize;
+    return handler.apply(size, idSize, offsetSize, idStart, offsetStart, dataStart);
+  }
+
+  public interface ArrayHandler<T> {
+    /**
+     * @param size Number of array elements.
+     * @param offsetSize The integer size of the offset list.
+     * @param offsetStart The starting index of the offset list in the variant value array.
+     * @param dataStart The starting index of element data in the variant value array.
+     */
+    T apply(int size, int offsetSize, int offsetStart, int dataStart);
+  }
+
+  // A helper function to access a variant array.
+  public static <T> T handleArray(byte[] value, int pos, ArrayHandler<T> handler) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != ARRAY) throw unexpectedType(Type.ARRAY);
+    // Refer to the comment of the `ARRAY` constant for the details of the array header
+    // encoding. Suppose `typeInfo` has a bit representation of 000_b2_b1b0, the following
+    // line extracts b2 to determine whether the array uses a 1/4-byte size.
+    boolean largeSize = ((typeInfo >> 2) & 0x1) != 0;
+    int sizeBytes = (largeSize ? U32_SIZE : 1);
+    int size = readUnsigned(value, pos + 1, sizeBytes);
+    // Extracts b1b0 to determine the integer size of the offset list.
+    int offsetSize = (typeInfo & 0x3) + 1;
+    int offsetStart = pos + 1 + sizeBytes;
+    int dataStart = offsetStart + (size + 1) * offsetSize;
+    return handler.apply(size, offsetSize, offsetStart, dataStart);
+  }
+
+  // Get a key at `id` in the variant metadata.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed. An out-of-bound `id` is also
+  // considered a malformed variant because it is read from the corresponding variant value.
+  public static String getMetadataKey(byte[] metadata, int id) {
+    checkIndex(0, metadata.length);
+    // Extracts the highest 2 bits in the metadata header to determine the integer size of
+    // the offset list.
+    int offsetSize = ((metadata[0] >> 6) & 0x3) + 1;
+    int dictSize = readUnsigned(metadata, 1, offsetSize);
+    if (id >= dictSize) throw malformedVariant();
+    // There are a header byte, a `dictSize` with `offsetSize` bytes, and `(dictSize + 1)`
+    // offsets before the string data.
+    int stringStart = 1 + (dictSize + 2) * offsetSize;
+    int offset = readUnsigned(metadata, 1 + (id + 1) * offsetSize, offsetSize);
+    int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize);
+    if (offset > nextOffset) throw malformedVariant();
+    checkIndex(stringStart + nextOffset - 1, metadata.length);
+    return new String(metadata, stringStart + offset, nextOffset - offset);
+  }
 }
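
As a standalone restatement of the little-endian byte math in `readLong`/`readUnsigned` above
(a hedged sketch, independent of the Spark classes):

    // Assemble a little-endian unsigned int from `numBytes` bytes. Masking each
    // byte with 0xFF keeps negative JVM bytes from sign-extending into the result.
    def readUnsignedSketch(bytes: Array[Byte], pos: Int, numBytes: Int): Int = {
      var result = 0
      for (i <- 0 until numBytes) {
        result |= (bytes(pos + i) & 0xFF) << (8 * i)
      }
      result
    }

    // 0x01020304 stored little-endian is the byte sequence (4, 3, 2, 1).
    assert(readUnsignedSketch(Array[Byte](4, 3, 2, 1), 0, 4) == 0x01020304)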
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index bab64caa3888..8b666c1ef9c8 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1589,6 +1589,12 @@ Parse Mode: `<failFastMode>`. To process malformed records as null result, try s
 
For more details see [MALFORMED_RECORD_IN_PARSING](sql-error-conditions-malformed-record-in-parsing-error-class.html)
 
+### MALFORMED_VARIANT
+
+[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
+
+Variant binary is malformed. Please check the data source is valid.
+
 ### MERGE_CARDINALITY_VIOLATION
 
[SQLSTATE: 23K01](sql-error-conditions-sqlstates.html#class-23-integrity-constraint-violation)
@@ -2732,6 +2738,12 @@ The variable `<variableName>` cannot be found. Verify the spelling and correctne
If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog.
 To tolerate the error on drop use DROP VARIABLE IF EXISTS.
 
+### VARIANT_CONSTRUCTOR_SIZE_LIMIT
+
+[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
+
+Cannot construct a Variant larger than 16 MiB. The maximum allowed size of a Variant value is 16 MiB.
+
 ### VARIANT_SIZE_LIMIT
 
 [SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index b155987242b3..f35c6da4f8af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 import org.apache.spark.util.Utils
 
 private[this] sealed trait PathInstruction
@@ -813,13 +813,17 @@ case class StructsToJson(
         (map: Any) =>
           gen.write(map.asInstanceOf[MapData])
           getAndReset()
+      case _: VariantType =>
+        (v: Any) =>
+          gen.write(v.asInstanceOf[VariantVal])
+          getAndReset()
     }
   }
 
   override def dataType: DataType = StringType
 
   override def checkInputDataTypes(): TypeCheckResult = inputSchema match {
-    case dt @ (_: StructType | _: MapType | _: ArrayType) =>
+    case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) =>
       JacksonUtils.verifyType(prettyName, dt)
     case _ =>
       DataTypeMismatch(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index c2c6117e1e3a..1964b5f24b34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -46,11 +47,13 @@ class JacksonGenerator(
   // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`.
   private type ValueWriter = (SpecializedGetters, Int) => Unit
 
-  // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`.
+  // `JackGenerator` can only be initialized with a `StructType`, a `MapType`, a `ArrayType` or a
+  // `VariantType`.
   require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType]
-    || dataType.isInstanceOf[ArrayType],
+    || dataType.isInstanceOf[ArrayType] || dataType.isInstanceOf[VariantType],
     s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " +
-      s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}")
+      s"${MapType.simpleString}, ${ArrayType.simpleString} or ${VariantType.simpleString} but " +
+      s"got ${dataType.catalogString}")
 
   // `ValueWriter`s for all fields of the schema
   private lazy val rootFieldWriters: Array[ValueWriter] = dataType match {
@@ -202,6 +205,9 @@ class JacksonGenerator(
       (row: SpecializedGetters, ordinal: Int) =>
         writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter))
 
+    case VariantType =>
+      (row: SpecializedGetters, ordinal: Int) => write(row.getVariant(ordinal))
+
     // For UDT values, they should be in the SQL type's corresponding value type.
     // We should not see values in the user-defined class at here.
     // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is
@@ -310,6 +316,10 @@ class JacksonGenerator(
       mapType = dataType.asInstanceOf[MapType]))
   }
 
+  def write(v: VariantVal): Unit = {
+    gen.writeRawValue(v.toString)
+  }
+
   def writeLineEnding(): Unit = {
     // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
     gen.writeRaw(lineSeparator)
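
With the generator change above, variant columns also flow through `to_json` in the DataFrame
API; a hedged usage sketch (assumes an active `spark` session):

    import org.apache.spark.sql.functions.{col, expr, to_json}

    // Parse a JSON literal into a variant column, then format it back to JSON.
    val df = spark.range(1).select(expr("""parse_json('{"k": [1, 2]}')""").as("v"))
    df.select(to_json(col("v"))).show(truncate = false) // expected cell: {"k":[1,2]}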
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 22155c927e37..2793b1c8c1fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -18,11 +18,22 @@
 package org.apache.spark.sql.catalyst.expressions.variant
 
 import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.types.variant.VariantUtil._
 import org.apache.spark.unsafe.types.VariantVal
 
 class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+  // Zero-extend each byte in the array with the appropriate number of bytes.
+  // Used to manually construct variant binary values with a given offset size.
+  // E.g. padded(Array(1,2,3), 3) will produce Array(1,0,0,2,0,0,3,0,0).
+  private def padded(a: Array[Byte], size: Int): Array[Byte] = {
+    a.flatMap { b =>
+      val padding = List.fill(size - 1)(0.toByte)
+      b :: padding
+    }
+  }
+
   test("parse_json") {
     def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
       checkEvaluation(ParseJson(Literal(json)), new VariantVal(expectedValue, expectedMetadata))
@@ -111,4 +122,246 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
         "Cannot build variant bigger than 16.0 MiB")
     }
   }
+
+  test("round-trip") {
+    def check(input: String, output: String = null): Unit = {
+      checkEvaluation(
+        StructsToJson(Map.empty, ParseJson(Literal(input))),
+        if (output != null) output else input
+      )
+    }
+
+    check("null")
+    check("true")
+    check("false")
+    check("-1")
+    check("1.0E10")
+    check("\"\"")
+    check("\"" + ("a" * 63) + "\"")
+    check("\"" + ("b" * 64) + "\"")
+    // scalastyle:off nonascii
+    check("\"" + ("你好,世界" * 20) + "\"")
+    // scalastyle:on nonascii
+    check("[]")
+    check("{}")
+    // scalastyle:off nonascii
+    check(
+      "[null, true,   false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+      "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+    )
+    // scalastyle:on nonascii
+    check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+  }
+
+  test("to_json with nested variant") {
+    checkEvaluation(
+      StructsToJson(Map.empty, CreateArray(Seq(ParseJson(Literal("{}")),
+        ParseJson(Literal("\"\"")),
+        ParseJson(Literal("[1, 2, 3]"))))),
+      "[{},\"\",[1,2,3]]"
+    )
+    checkEvaluation(
+      StructsToJson(Map.empty, CreateNamedStruct(Seq(
+        Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str" 
}""")),
+        Literal("b"), ParseJson(Literal("[[]]")),
+        Literal("c"), ParseJson(Literal("false"))))),
+      """{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}"""
+    )
+  }
+
+  test("to_json malformed") {
+    def check(value: Array[Byte], metadata: Array[Byte],
+              errorClass: String = "MALFORMED_VARIANT"): Unit = {
+      checkErrorInExpression[SparkRuntimeException](
+        ResolveTimeZone.resolveTimeZones(
+          StructsToJson(Map.empty, Literal(new VariantVal(value, metadata)))),
+        errorClass
+      )
+    }
+
+    val emptyMetadata = Array[Byte](VERSION, 0, 0)
+    // INT8 only has 7 byte content.
+    check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
+    // DECIMAL16 only has 15 byte content.
+    check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
+    // Short string content too short.
+    check(Array(shortStrHeader(2), 'x'), emptyMetadata)
+    // Long string length too short (requires 4 bytes).
+    check(Array(primitiveHeader(LONG_STR), 0, 0, 0), emptyMetadata)
+    // Long string content too short.
+    check(Array(primitiveHeader(LONG_STR), 1, 0, 0, 0), emptyMetadata)
+    // Size is 1 but no content.
+    check(Array(arrayHeader(false, 1),
+      /* size */ 1,
+      /* offset list */ 0), emptyMetadata)
+    // Requires a 4-byte size but the actual size only has one byte.
+    check(Array(arrayHeader(true, 1),
+      /* size */ 0,
+      /* offset list */ 0), emptyMetadata)
+    // Offset out of bound.
+    check(Array(arrayHeader(false, 1),
+      /* size */ 1,
+      /* offset list */ 1, 1), emptyMetadata)
+    // Id out of bound.
+    check(Array(objectHeader(false, 1, 1),
+      /* size */ 1,
+      /* id list */ 0,
+      /* offset list */ 0, 2,
+      /* field data */ primitiveHeader(INT1), 1), emptyMetadata)
+    // Variant version is not 1.
+    check(Array(primitiveHeader(INT1), 0), Array[Byte](3, 0, 0))
+    check(Array(primitiveHeader(INT1), 0), Array[Byte](2, 0, 0))
+
+    // Construct binary values that are over 1 << 24 bytes, but otherwise valid.
+    val bigVersion = Array[Byte]((VERSION | (3 << 6)).toByte)
+    val a = Array.fill(1 << 24)('a'.toByte)
+    val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1) ++
+      a ++ Array[Byte]('b')
+    check(Array(primitiveHeader(TRUE)), hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+
+    // The keys are 'aaa....' and 'b'. Values are "yyy..." and 'true'.
+    val y = Array.fill(1 << 24)('y'.toByte)
+    val hugeObject = Array[Byte](objectHeader(true, 4, 4)) ++
+      /* size */ padded(Array(2), 4) ++
+      /* id list */ padded(Array(0, 1), 4) ++
+      // Second value starts at offset 5 + (1 << 24), which is `5001` little-endian. The last
+      // value is 1 byte, so the one-past-the-end value is `6001`.
+      /* offset list */ Array[Byte](0, 0, 0, 0, 5, 0, 0, 1, 6, 0, 0, 1) ++
+      /* field data */ Array[Byte](primitiveHeader(LONG_STR), 0, 0, 0, 1) ++ y ++ Array[Byte](
+        primitiveHeader(TRUE)
+      )
+
+    val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0) ++
+      Array[Byte]('a', 'b')
+    check(hugeObject, smallMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+    check(hugeObject, hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+  }
+
+  // Test valid forms of Variant that our writer would never produce.
+  test("to_json valid input") {
+    def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+      checkEvaluation(
+        StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+        expectedJson
+      )
+    }
+    // Some valid metadata formats. Check that they aren't rejected.
+    // Sorted string bit is set, and can be ignored.
+    val emptyMetadata2 = Array[Byte](VERSION | 1 << 4, 0, 0)
+    // Bit 5 is not defined in the spec, and can be ignored.
+    val emptyMetadata3 = Array[Byte](VERSION | 1 << 5, 0, 0)
+    // Can specify 3 bytes per size/offset, even if they aren't needed.
+    val header = (VERSION | (2 << 6)).toByte
+    val emptyMetadata4 = Array[Byte](header, 0, 0, 0, 0, 0, 0)
+    check("true", Array(primitiveHeader(TRUE)), emptyMetadata2)
+    check("true", Array(primitiveHeader(TRUE)), emptyMetadata3)
+    check("true", Array(primitiveHeader(TRUE)), emptyMetadata4)
+  }
+
+  // Test StructsToJson with manually constructed input that uses up to 4 bytes for offsets
+  // and sizes. We never produce 4-byte offsets, since they're only needed for >16 MiB values,
+  // which we error out on, but the reader should be able to handle them if some other writer
+  // decides to use them for smaller values.
+  test("to_json with large offsets and sizes") {
+    def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+      checkEvaluation(
+        StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+        expectedJson
+      )
+    }
+
+    for {
+      offsetSize <- 1 to 4
+      idSize <- 1 to 4
+      metadataSize <- 1 to 4
+      largeSize <- Seq(false, true)
+    } {
+      // Test array
+      val version = Array[Byte]((VERSION | ((metadataSize - 1) << 6)).toByte)
+      val emptyMetadata = version ++ padded(Array(0, 0), metadataSize)
+      // Construct a binary with the given sizes. Regardless, to_json should produce the
+      // same result.
+      val arrayValue = Array[Byte](arrayHeader(largeSize, offsetSize)) ++
+        /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
+        /* offset list */ padded(Array(0, 1, 4, 5), offsetSize) ++
+        Array[Byte](/* values */ primitiveHeader(FALSE),
+            primitiveHeader(INT2), 2, 1, primitiveHeader(NULL))
+      check("[false,258,null]", arrayValue, emptyMetadata)
+
+      // Test object
+      val metadata = version ++
+                     padded(Array(3, 0, 1, 2, 3), metadataSize) ++
+                     Array[Byte]('a', 'b', 'c')
+      val objectValue = Array[Byte](objectHeader(largeSize, idSize, offsetSize)) ++
+        /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
+        /* id list */ padded(Array(0, 1, 2), idSize) ++
+        /* offset list */ padded(Array(0, 2, 4, 6), offsetSize) ++
+        /* field data */ Array[Byte](primitiveHeader(INT1), 1,
+            primitiveHeader(INT1), 2, shortStrHeader(1), '3')
+
+      check("""{"a":1,"b":2,"c":"3"}""", objectValue, metadata)
+    }
+  }
+
+  test("to_json large binary") {
+    def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+      checkEvaluation(
+        StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+        expectedJson
+      )
+    }
+
+    // Create a binary that uses the max 1 << 24 bytes for both metadata and value.
+    val bigVersion = Array[Byte]((VERSION | (2 << 6)).toByte)
+    // Create a single huge value, followed by a one-byte string. We'll have 1 header byte,
+    // plus 12 bytes for size and offsets, plus 1 byte for the final value, so the large value
+    // is 1 << 24 - 14 bytes, or (-14, -1, -1) as a signed little-endian value.
+    val aSize = (1 << 24) - 14
+    val a = Array.fill(aSize)('a'.toByte)
+    val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, -14, -1, -1, -13, -1, -1) ++
+      a ++ Array[Byte]('b')
+    // Validate metadata in isolation.
+    check("true", Array(primitiveHeader(TRUE)), hugeMetadata)
+
+    // The object will contain a large string, and the following bytes:
+    // - object header and size: 1+4 bytes
+    // - ID list: 6 bytes
+    // - offset list: 9 bytes
+    // - field headers and string length: 6 bytes
+    // In order to get the full binary to 1 << 24, the large string is (1 << 24) - 26 bytes.
+    // As a signed little-endian value, this is (-26, -1, -1).
+    val ySize = (1 << 24) - 26
+    val y = Array.fill(ySize)('y'.toByte)
+    val hugeObject = Array[Byte](objectHeader(true, 3, 3)) ++
+      /* size */ padded(Array(2), 4) ++
+      /* id list */ padded(Array(0, 1), 3) ++
+      // Second offset is (-26,-1,-1), plus 5 bytes for string header, so (-21,-1,-1).
+      /* offset list */ Array[Byte](0, 0, 0, -21, -1, -1, -20, -1, -1) ++
+      /* field data */ Array[Byte](primitiveHeader(LONG_STR), -26, -1, -1, 0) ++ y ++ Array[Byte](
+        primitiveHeader(TRUE)
+      )
+    // Same as hugeObject, but with a short string.
+    val smallObject = Array[Byte](objectHeader(false, 1, 1)) ++
+      /* size */ Array[Byte](2) ++
+      /* id list */ Array[Byte](0, 1) ++
+      /* offset list */ Array[Byte](0, 6, 7) ++
+      /* field data */ Array[Byte](primitiveHeader(LONG_STR), 1, 0, 0, 0, 'y',
+          primitiveHeader(TRUE))
+    val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0) ++
+      Array[Byte]('a', 'b')
+
+    // Check all combinations of large/small value and metadata.
+    val expectedResult1 =
+      s"""{"${a.map(_.toChar).mkString}":"${y.map(_.toChar).mkString}","b":true}"""
+    check(expectedResult1, hugeObject, hugeMetadata)
+    val expectedResult2 =
+      s"""{"${a.map(_.toChar).mkString}":"y","b":true}"""
+    check(expectedResult2, smallObject, hugeMetadata)
+    val expectedResult3 =
+      s"""{"a":"${y.map(_.toChar).mkString}","b":true}"""
+    check(expectedResult3, hugeObject, smallMetadata)
+    val expectedResult4 =
+      s"""{"a":"y","b":true}"""
+    check(expectedResult4, smallObject, smallMetadata)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 1a2c424938a1..3991b44d0bbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -30,10 +30,7 @@ import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.ArrayImplicits._
 
 class VariantSuite extends QueryTest with SharedSparkSession {
-  // TODO(SPARK-45891): We need to ignore some tests for now because the `toString` implementation
-  // doesn't match the `parse_json` implementation yet. We will shortly add a new `toString`
-  // implementation and re-enable the tests.
-  ignore("basic tests") {
+  test("basic tests") {
     def verifyResult(df: DataFrame): Unit = {
       val result = df.collect()
         .map(_.get(0).asInstanceOf[VariantVal].toString)
@@ -43,8 +40,6 @@ class VariantSuite extends QueryTest with SharedSparkSession {
       assert(result == expected)
     }
 
-    // At this point, JSON parsing logic is not really implemented. We just construct some number
-    // inputs that are also valid JSON. This exercises passing VariantVal throughout the system.
     val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)")
     verifyResult(query)
 
@@ -142,7 +137,7 @@ class VariantSuite extends QueryTest with SharedSparkSession {
     }
   }
 
-  ignore("write partitioned file") {
+  test("write partitioned file") {
     def verifyResult(df: DataFrame): Unit = {
       val result = df.selectExpr("v").collect()
         .map(_.get(0).asInstanceOf[VariantVal].toString)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 4c77f26949b0..19251330cffe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -196,9 +196,6 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
     }
     val exampleRe = """^(.+);\n(?s)(.+)$""".r
     val ignoreSet = Set(
-      // TODO(SPARK-45891): need to temporarily ignore it because the `toString` implementation
-      // doesn't match the `parse_json` implementation yet.
-      "org.apache.spark.sql.catalyst.expressions.variant.ParseJson",
       // One of examples shows getting the current timestamp
       "org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
       "org.apache.spark.sql.catalyst.expressions.CurrentDate",

