This is an automated email from the ASF dual-hosted git repository.

yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e55381e18 [GLUTEN-11088] Fix GlutenDataFrameFunctionsSuite in Spark-4.0 (#11195)
8e55381e18 is described below

commit 8e55381e1806b2045f2a0993f8d861e76fca1135
Author: Mingliang Zhu <[email protected]>
AuthorDate: Fri Nov 28 18:24:57 2025 +0800

    [GLUTEN-11088] Fix GlutenDataFrameFunctionsSuite in Spark-4.0 (#11195)
    
    
https://github.com/apache/spark/blob/29434ea766b0fc3c3bf6eaadb43a8f931133649e/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2928-L2937
    Vanilla Spark throws SparkRuntimeException, while Gluten throws SparkException.
    This patch modifies the tests to match Gluten's behavior.
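    
    A minimal sketch of the adaptation pattern (names such as df1 and
    expected1a refer to the fixtures in the patch below):
    
        import org.apache.spark.SparkException
        import org.apache.spark.sql.internal.SQLConf
    
        // Under the default EXCEPTION dedup policy, duplicate map keys fail
        // the query. Vanilla Spark raises SparkRuntimeException; with Gluten
        // the failure surfaces as a SparkException, so the rewritten tests
        // intercept the broader type:
        intercept[SparkException] {
          df1.selectExpr("map_concat(map1, map2)").collect()
        }
    
        // With LAST_WIN deduplication both engines succeed, so results can
        // be checked directly:
        withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key ->
          SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
          checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
        }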
---
 .../gluten/utils/velox/VeloxTestSettings.scala     |   2 +-
 .../spark/sql/GlutenDataFrameFunctionsSuite.scala  | 229 +++++++++++++++++++++
 2 files changed, 230 insertions(+), 1 deletion(-)

diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 74ace889e9..5fa43d50f7 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -758,7 +758,7 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("aggregate function - array for non-primitive type")
    // Rewrite this test because Velox sorts rows by key for primitive data types, which disrupts the original row sequence.
     .exclude("map_zip_with function - map of primitive types")
-    // TODO: fix in Spark-4.0
+    // Vanilla Spark throws SparkRuntimeException, while Gluten throws SparkException.
     .exclude("map_concat function")
     .exclude("transform keys function - primitive data types")
   enableSuite[GlutenDataFrameHintSuite]
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790a..49f6052b20 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType}
 
 class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenSQLTestsTrait {
   import testImplicits._
@@ -49,4 +52,230 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("map_concat function") {
+    val df1 = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+      (null, Map[Int, Int](3 -> 300, 4 -> 400))
+    ).toDF("map1", "map2")
+
+    val expected1a = Seq(
+      Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+      Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+      Row(null)
+    )
+
+    intercept[SparkException](df1.selectExpr("map_concat(map1, map2)").collect())
+    intercept[SparkException](df1.select(map_concat($"map1", $"map2")).collect())
+    withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+      checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+      checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+    }
+
+    val expected1b = Seq(
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+    checkAnswer(df1.select(map_concat($"map1")), expected1b)
+
+    val df2 = Seq(
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
+    ).toDF("map1", "map2")
+
+    val expected2 = Seq(Row(Map()))
+
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
+
+    val df3 = {
+      val schema = StructType(
+        StructField("map1", MapType(StringType, IntegerType, true), false) ::
+          StructField("map2", MapType(StringType, IntegerType, false), false) 
:: Nil
+      )
+      val data = Seq(
+        Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)),
+        Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4))
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val expected3 = Seq(
+      Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+      Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+    )
+
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3)
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, map2)").collect()
+      },
+      condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(fragment = "map_concat(map1, map2)", start = 0, stop = 21)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", $"map2")).collect()
+      },
+      condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, map2)\"",
+        "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+        "functionName" -> "`map_concat`"),
+      context =
+        ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("map_concat(map1, 12)").collect()
+      },
+      condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`"),
+      context = ExpectedContext(fragment = "map_concat(map1, 12)", start = 0, stop = 19)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.select(map_concat($"map1", lit(12))).collect()
+      },
+      condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"map_concat(map1, 12)\"",
+        "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+        "functionName" -> "`map_concat`"),
+      context =
+        ExpectedContext(fragment = "map_concat", callSitePattern = getCurrentClassCallSitePattern)
+    )
+  }
+
+  testGluten("transform keys function - primitive data types") {
+    val dfExample1 = Seq(
+      Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
+    ).toDF("i")
+
+    val dfExample2 = Seq(
+      Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
+    ).toDF("j")
+
+    val dfExample3 = Seq(
+      Map[Int, Boolean](25 -> true, 26 -> false)
+    ).toDF("x")
+
+    val dfExample4 = Seq(
+      Map[Array[Int], Boolean](Array(1, 2) -> false)
+    ).toDF("y")
+
+    def testMapOfPrimitiveTypesCombination(): Unit = {
+      checkAnswer(
+        dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
+        Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+      checkAnswer(
+        dfExample1.select(transform_keys(col("i"), (k, v) => k + v)),
+        Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+      checkAnswer(
+        dfExample2.selectExpr(
+          "transform_keys(j, " +
+            "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 
'three'))[k])"),
+        Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+      )
+
+      checkAnswer(
+        dfExample2.select(
+          transform_keys(
+            col("j"),
+            (k, v) =>
+              element_at(
+                map_from_arrays(
+                  array(lit(1), lit(2), lit(3)),
+                  array(lit("one"), lit("two"), lit("three"))
+                ),
+                k
+              )
+          )
+        ),
+        Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+      )
+
+      checkAnswer(
+        dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
+        Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.select(transform_keys(col("j"), (k, v) => (v * 2).cast("bigint") + k)),
+        Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
+        Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+      checkAnswer(
+        dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
+        Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+      intercept[SparkException] {
+        dfExample3.selectExpr("transform_keys(x, (k, v) ->  k % 2 = 0 OR v)").collect()
+      }
+      intercept[SparkException] {
+        dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)).collect()
+      }
+      withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+        checkAnswer(
+          dfExample3.selectExpr("transform_keys(x, (k, v) ->  k % 2 = 0 OR v)"),
+          Seq(Row(Map(true -> true, true -> false))))
+
+        checkAnswer(
+          dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
+          Seq(Row(Map(true -> true, true -> false))))
+      }
+
+      checkAnswer(
+        dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
+        Seq(Row(Map(50 -> true, 78 -> false))))
+
+      checkAnswer(
+        dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k * 2).otherwise(k * 3))),
+        Seq(Row(Map(50 -> true, 78 -> false))))
+
+      checkAnswer(
+        dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
+        Seq(Row(Map(false -> false))))
+
+      checkAnswer(
+        dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k, lit(3)) && v)),
+        Seq(Row(Map(false -> false))))
+    }
+
+    // Test with local relation; the Project will be evaluated without codegen
+    testMapOfPrimitiveTypesCombination()
+    dfExample1.cache()
+    dfExample2.cache()
+    dfExample3.cache()
+    dfExample4.cache()
+    // Test with cached relation; the Project will be evaluated with codegen
+    testMapOfPrimitiveTypesCombination()
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
