This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b371e7dd8800 [SPARK-48224][SQL] Disallow map keys from being of 
variant type
b371e7dd8800 is described below

commit b371e7dd88009195740f8f5b591447441ea43d0b
Author: Harsh Motwani <harsh.motw...@databricks.com>
AuthorDate: Thu May 9 21:47:05 2024 -0700

    [SPARK-48224][SQL] Disallow map keys from being of variant type
    
    ### What changes were proposed in this pull request?
    
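    This PR disallows map keys from being of variant type, including keys 
that contain a variant nested inside another type (for example, a struct 
field). Therefore, SQL statements like `select map(parse_json('{"a": 1}'), 1)`, 
which previously worked, are now rejected with an `INVALID_MAP_KEY_TYPE` 
data type mismatch error.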
    
    ### Why are the changes needed?
    
    Using variant as the key type of a map has not been tested, so allowing 
it can result in undefined behavior.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Previously, users could use variants as map keys; with this PR, 
such maps are rejected.
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46516 from harshmotw-db/map_variant_key.
    
    Authored-by: Harsh Motwani <harsh.motw...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/catalyst/util/TypeUtils.scala |  2 +-
 .../catalyst/expressions/ComplexTypeSuite.scala    | 34 +++++++++++++++++++++-
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index d2c708b380cf..a0d578c66e73 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -58,7 +58,7 @@ object TypeUtils extends QueryErrorsBase {
   }
 
   def checkForMapKeyType(keyType: DataType): TypeCheckResult = {
-    if (keyType.existsRecursively(_.isInstanceOf[MapType])) {
+    if (keyType.existsRecursively(dt => dt.isInstanceOf[MapType] || 
dt.isInstanceOf[VariantType])) {
       DataTypeMismatch(
         errorSubClass = "INVALID_MAP_KEY_TYPE",
         messageParameters = Map(
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5f135e46a377..497b335289b1 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
 class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -359,6 +359,38 @@ class ComplexTypeSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     )
   }
 
+  // map key can't be variant
+  val map6 = CreateMap(Seq(
+    Literal.create(new VariantVal(Array[Byte](), Array[Byte]())),
+    Literal.create(1)
+  ))
+  map6.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as 
a part of map key")
+    case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
+      assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
+      assert(messageParameters === Map("keyType" -> "\"VARIANT\""))
+  }
+
+  // map key can't contain variant
+  val map7 = CreateMap(
+    Seq(
+      CreateStruct(
+        Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), 
Array[Byte]())))
+      ),
+      Literal.create(1)
+    )
+  )
+  map7.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as 
a part of map key")
+    case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
+      assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
+      assert(
+        messageParameters === Map(
+          "keyType" -> "\"STRUCT<col1: INT NOT NULL, col2: VARIANT NOT NULL>\""
+        )
+      )
+  }
+
   test("MapFromArrays") {
     val intSeq = Seq(5, 10, 15, 20, 25)
     val longSeq = intSeq.map(_.toLong)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to