This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ffa4d198cec6 [SPARK-48067][SQL] Fix variant default columns
ffa4d198cec6 is described below

commit ffa4d198cec6620f0385a0e428b023d2ac4e3d5c
Author: Richard Chen <r.c...@databricks.com>
AuthorDate: Thu May 2 12:22:02 2024 -0700

    [SPARK-48067][SQL] Fix variant default columns
    
    ### What changes were proposed in this pull request?
    
    Changes the literal `sql` representation of a variant value to
    `parse_json(variant.toJson)`, because there is no other SQL
    representation of a literal variant.
    
    This allows variant default columns to work, because default columns
    store a literal string representation of the default value in the
    schema struct field's metadata.
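
    For illustration, a minimal sketch (not part of this patch) of the
    resulting behavior, assuming a Spark build that includes this change:
    ```
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
    import org.apache.spark.sql.types.VariantType
    import org.apache.spark.unsafe.types.UTF8String

    val v = VariantExpressionEvalUtils.parseJson(UTF8String.fromString("""{"k": "v"}"""))
    val lit = Literal.create(v, VariantType)
    // With this change, `lit.sql` renders a re-parseable expression,
    // e.g. PARSE_JSON('{"k":"v"}'), instead of falling back to `value.toString`.
    println(lit.sql)
    ```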
    
    ### Why are the changes needed?
    
    Previously, we could not define a variant default column such as:
    ```
    create table t(
            v6 variant default parse_json('{\"k\": \"v\"}')
    )
    ```
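
    With this change, the statement above succeeds and the stored default is
    applied on insert. A minimal usage sketch (assuming the table `t` from the
    example above was created with this patch applied):
    ```
    spark.sql("insert into t values (default)")
    // v6 is populated with the default variant value {"k":"v"}
    spark.sql("select v6 from t").show()
    ```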
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46312 from richardc-db/fix_variant_default_cols.
    
    Authored-by: Richard Chen <r.c...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/expressions/literals.scala  |   4 +
 .../scala/org/apache/spark/sql/VariantSuite.scala  | 145 ++++++++++++++++++++-
 2 files changed, 146 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0fad3eff2da5..4cffc7f0b53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -42,6 +42,7 @@ import org.json4s.JsonAST._
 
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
 import org.apache.spark.sql.catalyst.types._
@@ -204,6 +205,8 @@ object Literal {
       create(new GenericInternalRow(
         struct.fields.map(f => default(f.dataType).value)), struct)
     case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt)
+    case VariantType =>
+      create(VariantExpressionEvalUtils.castToVariant(0, IntegerType), VariantType)
     case other =>
       throw QueryExecutionErrors.noDefaultForDataTypeError(dataType)
   }
@@ -549,6 +552,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
           s"${Literal(kv._1, mapType.keyType).sql}, ${Literal(kv._2, 
mapType.valueType).sql}"
         }
       s"MAP(${keysAndValues.mkString(", ")})"
+    case (v: VariantVal, variantType: VariantType) => s"PARSE_JSON('${v.toJson(timeZoneId)}')"
     case _ => value.toString
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 19e5f9ba63e6..caab98b6239a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -26,15 +26,17 @@ import scala.jdk.CollectionConverters._
 import scala.util.Random
 
 import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils, VariantGet}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.VariantVal
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 import org.apache.spark.util.ArrayImplicits._
 
-class VariantSuite extends QueryTest with SharedSparkSession {
+class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
   import testImplicits._
 
   test("basic tests") {
@@ -445,4 +447,141 @@ class VariantSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-48067: default variant columns works") {
+    withTable("t") {
+      sql("""create table t(
+        v1 variant default null,
+        v2 variant default parse_json(null),
+        v3 variant default cast(null as variant),
+        v4 variant default parse_json('1'),
+        v5 variant default parse_json('1'),
+        v6 variant default parse_json('{\"k\": \"v\"}'),
+        v7 variant default cast(5 as int),
+        v8 variant default cast('hello' as string),
+        v9 variant default parse_json(to_json(parse_json('{\"k\": \"v\"}')))
+      ) using parquet""")
+      sql("""insert into t values(DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, 
DEFAULT, DEFAULT,
+        DEFAULT, DEFAULT)""")
+
+      val expected = sql("""select
+        cast(null as variant) as v1,
+        parse_json(null) as v2,
+        cast(null as variant) as v3,
+        parse_json('1') as v4,
+        parse_json('1') as v5,
+        parse_json('{\"k\": \"v\"}') as v6,
+        cast(cast(5 as int) as variant) as v7,
+        cast('hello' as variant) as v8,
+        parse_json(to_json(parse_json('{\"k\": \"v\"}'))) as v9
+      """)
+      val actual = sql("select * from t")
+      checkAnswer(actual, expected.collect())
+    }
+  }
+
+  Seq(
+    (
+      "basic int parse json",
+      VariantExpressionEvalUtils.parseJson(UTF8String.fromString("1")),
+      VariantType
+    ),
+    (
+      "basic json parse json",
+      VariantExpressionEvalUtils.parseJson(UTF8String.fromString("{\"k\": \"v\"}")),
+      VariantType
+    ),
+    (
+      "basic null parse json",
+      VariantExpressionEvalUtils.parseJson(UTF8String.fromString("null")),
+      VariantType
+    ),
+    (
+      "basic null",
+      null,
+      VariantType
+    ),
+    (
+      "basic array",
+      new GenericArrayData(Array[Int](1, 2, 3, 4, 5)),
+      new ArrayType(IntegerType, false)
+    ),
+    (
+      "basic string",
+      UTF8String.fromString("literal string"),
+      StringType
+    ),
+    (
+      "basic timestamp",
+      0L,
+      TimestampType
+    ),
+    (
+      "basic int",
+      0,
+      IntegerType
+    ),
+    (
+      "basic struct",
+      Literal.default(new StructType().add("col0", StringType)).eval(),
+      new StructType().add("col0", StringType)
+    ),
+    (
+      "complex struct with child variant",
+      Literal.default(new StructType()
+        .add("col0", StringType)
+        .add("col1", new StructType().add("col0", VariantType))
+        .add("col2", VariantType)
+        .add("col3", new ArrayType(VariantType, false))
+      ).eval(),
+      new StructType()
+        .add("col0", StringType)
+        .add("col1", new StructType().add("col0", VariantType))
+        .add("col2", VariantType)
+        .add("col3", new ArrayType(VariantType, false))
+    ),
+    (
+      "basic array with null",
+      new GenericArrayData(Array[Any](1, 2, null)),
+      new ArrayType(IntegerType, true)
+    ),
+    (
+      "basic map with null",
+      new ArrayBasedMapData(
+        new GenericArrayData(Array[Any](UTF8String.fromString("k1"), UTF8String.fromString("k2"))),
+        new GenericArrayData(Array[Any](1, null))
+      ),
+      new MapType(StringType, IntegerType, true)
+    )
+  ).foreach { case (testName, value, dt) =>
+    test(s"SPARK-48067: Variant literal `sql` correctly recreates the variant 
- $testName") {
+      val l = Literal.create(
+        VariantExpressionEvalUtils.castToVariant(value, dt.asInstanceOf[DataType]), VariantType)
+      val jsonString = l.eval().asInstanceOf[VariantVal]
+        .toJson(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+      val expectedSql = s"PARSE_JSON('$jsonString')"
+      assert(l.sql == expectedSql)
+      val valueFromLiteralSql =
+        spark.sql(s"select ${l.sql}").collect()(0).getAs[VariantVal](0)
+
+      // Cast the variants to their specified type to compare for logical equality.
+      // Currently, variant equality naively compares its value and metadata binaries. However,
+      // variant equality is more complex than this.
+      val castVariantExpr = VariantGet(
+        l,
+        Literal.create(UTF8String.fromString("$"), StringType),
+        dt,
+        true,
+        Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString())
+      )
+      val sqlVariantExpr = VariantGet(
+        Literal.create(valueFromLiteralSql, VariantType),
+        Literal.create(UTF8String.fromString("$"), StringType),
+        dt,
+        true,
+        Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString())
+      )
+      checkEvaluation(castVariantExpr, sqlVariantExpr.eval())
+    }
+  }
 }

