This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 134a13928965 [SPARK-47681][SQL] Add schema_of_variant expression
134a13928965 is described below

commit 134a13928965e9818393511eadd504b9f1679766
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Tue Apr 9 00:04:02 2024 +0800

    [SPARK-47681][SQL] Add schema_of_variant expression
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new `SchemaOfVariant` expression. It returns the schema of a variant in SQL format.
    
    Usage examples:
    
    ```
    > SELECT schema_of_variant(parse_json('null'));
     VOID
    > SELECT schema_of_variant(parse_json('[{"b":true,"a":0}]'));
     ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
    ```
    
    ### Why are the changes needed?
    
    This expression can help the user explore the content of variant values.
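    
    For example, a user can inspect the shape of a variant column before extracting typed fields from it. A minimal sketch for spark-shell with this patch applied (the data and column names are illustrative):
    
    ```
    val df = Seq("""{"user": {"id": 1}, "tags": ["a", "b"]}""").toDF("j")
      .selectExpr("parse_json(j) AS v")
    // Struct fields are reported in alphabetical order, so this shows
    // STRUCT<tags: ARRAY<STRING>, user: STRUCT<id: BIGINT>>.
    df.selectExpr("schema_of_variant(v)").show(truncate = false)
    ```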
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.  A new SQL expression is added.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45806 from chenhao-db/variant_schema.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../spark/sql/catalyst/encoders/EncoderUtils.scala |  7 +-
 .../expressions/variant/variantExpressions.scala   | 86 ++++++++++++++++++++++
 .../spark/sql/catalyst/json/JsonInferSchema.scala  | 18 +++--
 .../sql-functions/sql-expression-schema.md         |  1 +
 .../apache/spark/sql/VariantEndToEndSuite.scala    | 31 ++++++++
 6 files changed, 134 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ecba8b263c41..bbc063c32103 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -822,6 +822,7 @@ object FunctionRegistry {
     expression[ParseJson]("parse_json"),
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
+    expression[SchemaOfVariant]("schema_of_variant"),
 
     // cast
     expression[Cast]("cast"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 45598b6a66f2..20f86a32c1a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, C
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, YearMonthIntervalType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 /**
  * Helper class for Generating [[ExpressionEncoder]]s.
@@ -122,7 +122,8 @@ object EncoderUtils {
     TimestampType -> classOf[PhysicalLongType.InternalType],
     TimestampNTZType -> classOf[PhysicalLongType.InternalType],
     BinaryType -> classOf[PhysicalBinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    VariantType -> classOf[VariantVal]
   )
 
   val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 4681326136c7..2f2b5923fed7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant
 
 import scala.util.parsing.combinator.RegexParsers
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.json.JsonInferSchema
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -403,3 +405,87 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true)
 )
 // scalastyle:on line.size.limit
 object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false)
+
+@ExpressionDescription(
+  usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(parse_json('null'));
+       VOID
+      > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
+       ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+  """,
+  since = "4.0.0",
+  group = "variant_funcs"
+)
+case class SchemaOfVariant(child: Expression)
+  extends UnaryExpression
+    with RuntimeReplaceable
+    with ExpectsInputTypes {
+  override lazy val replacement: Expression = StaticInvoke(
+    SchemaOfVariant.getClass,
+    StringType,
+    "schemaOfVariant",
+    Seq(child),
+    inputTypes,
+    returnNullable = false)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
+
+  override def dataType: DataType = StringType
+
+  override def prettyName: String = "schema_of_variant"
+
+  override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant =
+    copy(child = newChild)
+}
+
+object SchemaOfVariant {
+  /** The actual implementation of the `SchemaOfVariant` expression. */
+  def schemaOfVariant(input: VariantVal): UTF8String = {
+    val v = new Variant(input.getValue, input.getMetadata)
+    UTF8String.fromString(schemaOf(v).sql)
+  }
+
+  /**
+   * Return the schema of a variant. Struct fields are guaranteed to be sorted alphabetically.
+   */
+  def schemaOf(v: Variant): DataType = v.getType match {
+    case Type.OBJECT =>
+      val size = v.objectSize()
+      val fields = new Array[StructField](size)
+      for (i <- 0 until size) {
+        val field = v.getFieldAtIndex(i)
+        fields(i) = StructField(field.key, schemaOf(field.value))
+      }
+      // According to the variant spec, object fields must be sorted alphabetically. So we don't
+      // have to sort, but just need to validate they are sorted.
+      for (i <- 1 until size) {
+        if (fields(i - 1).name >= fields(i).name) {
+          throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty)
+        }
+      }
+      StructType(fields)
+    case Type.ARRAY =>
+      var elementType: DataType = NullType
+      for (i <- 0 until v.arraySize()) {
+        elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i)))
+      }
+      ArrayType(elementType)
+    case Type.NULL => NullType
+    case Type.BOOLEAN => BooleanType
+    case Type.LONG => LongType
+    case Type.STRING => StringType
+    case Type.DOUBLE => DoubleType
+    case Type.DECIMAL =>
+      val d = v.getDecimal
+      DecimalType(d.precision(), d.scale())
+  }
+
+  /**
+   * Returns the tightest common type for two given data types. Input struct fields are assumed to
+   * be sorted alphabetically.
+   */
+  def mergeSchema(t1: DataType, t2: DataType): DataType =
+    JsonInferSchema.compatibleType(t1, t2, VariantType)
+}
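
A note on `mergeSchema` above: it delegates to `JsonInferSchema.compatibleType` with `VariantType` as the fallback, so variant values whose components have no tight common type report `VARIANT` rather than `STRING`. A minimal spark-shell sketch of the user-visible effect (expected output taken from the `VariantEndToEndSuite` changes below; not part of the patch):

```
// BIGINT and STRING have no tight common type, so the element type
// falls back to VARIANT:
spark.sql("""SELECT schema_of_variant(parse_json('[1, "1"]'))""").show()
// Expected: ARRAY<VARIANT>
```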
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 12c1be7c0de7..7ee522226e3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -360,8 +360,10 @@ object JsonInferSchema {
 
   /**
    * Returns the most general data type for two given data types.
+   * When the two types are incompatible, return `defaultDataType` as a fallback result.
    */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
+  def compatibleType(
+      t1: DataType, t2: DataType, defaultDataType: DataType = StringType): DataType = {
     TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
@@ -399,7 +401,8 @@ object JsonInferSchema {
             val f2Name = fields2(f2Idx).name
             val comp = f1Name.compareTo(f2Name)
             if (comp == 0) {
-              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              val dataType = compatibleType(
+                fields1(f1Idx).dataType, fields2(f2Idx).dataType, defaultDataType)
               newFields.add(StructField(f1Name, dataType, nullable = true))
               f1Idx += 1
               f2Idx += 1
@@ -422,21 +425,22 @@ object JsonInferSchema {
           StructType(newFields.toArray(emptyStructFieldArray))
 
         case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+          ArrayType(
+            compatibleType(elementType1, elementType2, defaultDataType),
+            containsNull1 || containsNull2)
 
         // The case that given `DecimalType` is capable of given `IntegralType` is handled in
         // `findTightestCommonType`. Both cases below will be executed only when the given
         // `DecimalType` is not capable of the given `IntegralType`.
         case (t1: IntegralType, t2: DecimalType) =>
-          compatibleType(DecimalType.forType(t1), t2)
+          compatibleType(DecimalType.forType(t1), t2, defaultDataType)
         case (t1: DecimalType, t2: IntegralType) =>
-          compatibleType(t1, DecimalType.forType(t2))
+          compatibleType(t1, DecimalType.forType(t2), defaultDataType)
 
         case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
           TimestampType
 
-        // strings and every string is a Json object.
-        case (_, _) => StringType
+        case (_, _) => defaultDataType
       }
     }
   }
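
For reference, a sketch of the new `defaultDataType` parameter in isolation. `JsonInferSchema` is internal catalyst API; these calls only illustrate the semantics of the change above:

```
import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType
import org.apache.spark.sql.types._

// Compatible numeric types still widen: BIGINT + DECIMAL(1,0) -> DECIMAL(20,0).
compatibleType(LongType, DecimalType(1, 0), VariantType)
// Incompatible types now yield the caller-supplied fallback instead of STRING.
compatibleType(LongType, StringType, VariantType)  // returns VariantType
```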
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index c461e3ec09fb..05491034e6c7 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -437,6 +437,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> |
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
+| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index cf12001fa71b..d8b1dca21ca6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -81,4 +81,35 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     val expected = new VariantVal(v.getValue, v.getMetadata)
     checkAnswer(variantDF, Seq(Row(expected)))
   }
+
+  test("schema_of_variant") {
+    def check(json: String, expected: String): Unit = {
+      val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))")
+      checkAnswer(df, Seq(Row(expected)))
+    }
+
+    check("null", "VOID")
+    check("1", "BIGINT")
+    check("1.0", "DECIMAL(1,0)")
+    check("1E0", "DOUBLE")
+    check("true", "BOOLEAN")
+    check("\"2000-01-01\"", "STRING")
+    check("""{"a":0}""", "STRUCT<a: BIGINT>")
+    check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT<a: ARRAY<STRING>, b: 
STRUCT<c: STRING>>")
+    check("[]", "ARRAY<VOID>")
+    check("[false]", "ARRAY<BOOLEAN>")
+    check("[null, 1, 1.0]", "ARRAY<DECIMAL(20,0)>")
+    check("[null, 1, 1.1]", "ARRAY<DECIMAL(21,1)>")
+    check("[123456.789, 123.456789]", "ARRAY<DECIMAL(12,6)>")
+    check("[1, 11111111111111111111111111111111111111]", 
"ARRAY<DECIMAL(38,0)>")
+    check("[1.1, 11111111111111111111111111111111111111]", "ARRAY<DOUBLE>")
+    check("[1, \"1\"]", "ARRAY<VARIANT>")
+    check("[{}, true]", "ARRAY<VARIANT>")
+    check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: VOID, b: 
BIGINT, c: STRING>>")
+    check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: STRING, 
b: BIGINT>>")
+    check(
+      """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""",
+      "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>"
+    )
+  }
 }

