This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b371e7dd8800 [SPARK-48224][SQL] Disallow map keys from being of variant type b371e7dd8800 is described below commit b371e7dd88009195740f8f5b591447441ea43d0b Author: Harsh Motwani <harsh.motw...@databricks.com> AuthorDate: Thu May 9 21:47:05 2024 -0700 [SPARK-48224][SQL] Disallow map keys from being of variant type ### What changes were proposed in this pull request? This PR disallows map keys of variant type, including key types that contain a variant anywhere inside them (for example, a struct with a variant field). As a result, SQL statements such as `select map(parse_json('{"a": 1}'), 1)`, which previously succeeded, now fail type checking with an `INVALID_MAP_KEY_TYPE` error. ### Why are the changes needed? Using variant as the key type of a map has not been tested, so its behavior is undefined. ### Does this PR introduce _any_ user-facing change? Yes. Previously, users could use variants as map keys; this PR removes that ability. ### How was this patch tested? Unit tests in `ComplexTypeSuite`. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46516 from harshmotw-db/map_variant_key. 
Authored-by: Harsh Motwani <harsh.motw...@databricks.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../catalyst/expressions/ComplexTypeSuite.scala | 34 +++++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index d2c708b380cf..a0d578c66e73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -58,7 +58,7 @@ object TypeUtils extends QueryErrorsBase { } def checkForMapKeyType(keyType: DataType): TypeCheckResult = { - if (keyType.existsRecursively(_.isInstanceOf[MapType])) { + if (keyType.existsRecursively(dt => dt.isInstanceOf[MapType] || dt.isInstanceOf[VariantType])) { DataTypeMismatch( errorSubClass = "INVALID_MAP_KEY_TYPE", messageParameters = Map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5f135e46a377..497b335289b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -359,6 +359,38 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + // map key can't be 
variant + val map6 = CreateMap(Seq( + Literal.create(new VariantVal(Array[Byte](), Array[Byte]())), + Literal.create(1) + )) + map6.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"VARIANT\"")) + } + + // map key can't contain variant + val map7 = CreateMap( + Seq( + CreateStruct( + Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]()))) + ), + Literal.create(1) + ) + ) + map7.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert( + messageParameters === Map( + "keyType" -> "\"STRUCT<col1: INT NOT NULL, col2: VARIANT NOT NULL>\"" + ) + ) + } + test("MapFromArrays") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org