This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new cef3e0528 test: fix Spark 3.5 tests (#1482)
cef3e0528 is described below

commit cef3e05283f98f294eecae9c3460b0d79b222fcc
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Tue Mar 11 14:22:14 2025 -0700

    test: fix Spark 3.5 tests (#1482)
---
 .github/workflows/spark_sql_test.yml               |   2 +-
 .../java/org/apache/comet/parquet/BatchReader.java |  32 +--
 .../spark/sql/comet/shims/ShimTaskMetrics.scala    |  29 +++
 .../spark/sql/comet/shims/ShimTaskMetrics.scala    |  29 +++
 .../spark/sql/comet/shims/ShimTaskMetrics.scala    |  31 +++
 .../spark/sql/comet/shims/ShimTaskMetrics.scala    |  29 +++
 dev/diffs/{3.5.1.diff => 3.5.4.diff}               | 214 ++++++++++++---------
 docs/source/contributor-guide/spark-sql-tests.md   |   6 +-
 .../char_varchar_utils/read_side_padding.rs        |  22 ++-
 pom.xml                                            |   4 +-
 10 files changed, 266 insertions(+), 132 deletions(-)

diff --git a/.github/workflows/spark_sql_test.yml 
b/.github/workflows/spark_sql_test.yml
index b325a5193..8d60f0769 100644
--- a/.github/workflows/spark_sql_test.yml
+++ b/.github/workflows/spark_sql_test.yml
@@ -45,7 +45,7 @@ jobs:
       matrix:
         os: [ubuntu-24.04]
         java-version: [11]
-        spark-version: [{short: '3.4', full: '3.4.3'}, {short: '3.5', full: 
'3.5.1'}]
+        spark-version: [{short: '3.4', full: '3.4.3'}, {short: '3.5', full: 
'3.5.4'}]
         module:
           - {name: "catalyst", args1: "catalyst/test", args2: ""}
           - {name: "sql/core-1", args1: "", args2: sql/testOnly * -- -l 
org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest}
diff --git a/common/src/main/java/org/apache/comet/parquet/BatchReader.java 
b/common/src/main/java/org/apache/comet/parquet/BatchReader.java
index 675dae9e7..dbf1b8180 100644
--- a/common/src/main/java/org/apache/comet/parquet/BatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/BatchReader.java
@@ -21,8 +21,6 @@ package org.apache.comet.parquet;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Arrays;
@@ -35,8 +33,6 @@ import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 
 import scala.Option;
-import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -61,9 +57,9 @@ import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.Type;
 import org.apache.spark.TaskContext;
 import org.apache.spark.TaskContext$;
-import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
+import org.apache.spark.sql.comet.shims.ShimTaskMetrics;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
 import 
org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
 import org.apache.spark.sql.execution.metric.SQLMetric;
@@ -350,7 +346,8 @@ public class BatchReader extends RecordReader<Void, 
ColumnarBatch> implements Cl
     // Note that this tries to get thread local TaskContext object, if this is 
called at other
     // thread, it won't update the accumulator.
     if (taskContext != null) {
-      Option<AccumulatorV2<?, ?>> accu = 
getTaskAccumulator(taskContext.taskMetrics());
+      Option<AccumulatorV2<?, ?>> accu =
+          ShimTaskMetrics.getTaskAccumulator(taskContext.taskMetrics());
       if (accu.isDefined() && 
accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
         @SuppressWarnings("unchecked")
         AccumulatorV2<Integer, Integer> intAccum = (AccumulatorV2<Integer, 
Integer>) accu.get();
@@ -637,27 +634,4 @@ public class BatchReader extends RecordReader<Void, 
ColumnarBatch> implements Cl
       }
     }
   }
-
-  // Signature of externalAccums changed from returning a Buffer to returning 
a Seq. If comet is
-  // expecting a Buffer but the Spark version returns a Seq or vice versa, we 
get a
-  // method not found exception.
-  @SuppressWarnings("unchecked")
-  private Option<AccumulatorV2<?, ?>> getTaskAccumulator(TaskMetrics 
taskMetrics) {
-    Method externalAccumsMethod;
-    try {
-      externalAccumsMethod = 
TaskMetrics.class.getDeclaredMethod("externalAccums");
-      externalAccumsMethod.setAccessible(true);
-      String returnType = externalAccumsMethod.getReturnType().getName();
-      if (returnType.equals("scala.collection.mutable.Buffer")) {
-        return ((Buffer<AccumulatorV2<?, ?>>) 
externalAccumsMethod.invoke(taskMetrics))
-            .lastOption();
-      } else if (returnType.equals("scala.collection.Seq")) {
-        return ((Seq<AccumulatorV2<?, ?>>) 
externalAccumsMethod.invoke(taskMetrics)).lastOption();
-      } else {
-        return Option.apply(null); // None
-      }
-    } catch (NoSuchMethodException | InvocationTargetException | 
IllegalAccessException e) {
-      return Option.apply(null); // None
-    }
-  }
 }
diff --git 
a/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
 
b/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++ 
b/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+  def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_, 
_]] =
+    taskMetrics.externalAccums.lastOption
+}
diff --git 
a/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
 
b/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++ 
b/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+  def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_, 
_]] =
+    taskMetrics.externalAccums.lastOption
+}
diff --git 
a/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
 
b/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..2ca0ef277
--- /dev/null
+++ 
b/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+  def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_, 
_]] =
+    taskMetrics.withExternalAccums(identity[ArrayBuffer[AccumulatorV2[_, 
_]]](_)).lastOption
+}
diff --git 
a/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
 
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++ 
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+  def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_, 
_]] =
+    taskMetrics.externalAccums.lastOption
+}
diff --git a/dev/diffs/3.5.1.diff b/dev/diffs/3.5.4.diff
similarity index 96%
rename from dev/diffs/3.5.1.diff
rename to dev/diffs/3.5.4.diff
index 762cd948d..47bc3ccd0 100644
--- a/dev/diffs/3.5.1.diff
+++ b/dev/diffs/3.5.4.diff
@@ -1,5 +1,5 @@
 diff --git a/pom.xml b/pom.xml
-index 0f504dbee85..430ec217e59 100644
+index 8dc47f391f9..8a3e72133a8 100644
 --- a/pom.xml
 +++ b/pom.xml
 @@ -152,6 +152,8 @@
@@ -11,9 +11,9 @@ index 0f504dbee85..430ec217e59 100644
      <!--
      If you changes codahale.metrics.version, you also need to change
      the link to metrics.dropwizard.io in docs/monitoring.md.
-@@ -2787,6 +2789,25 @@
-         <artifactId>arpack</artifactId>
-         <version>${netlib.ludovic.dev.version}</version>
+@@ -2836,6 +2838,25 @@
+         <artifactId>okio</artifactId>
+         <version>${okio.version}</version>
        </dependency>
 +      <dependency>
 +        <groupId>org.apache.datafusion</groupId>
@@ -38,7 +38,7 @@ index 0f504dbee85..430ec217e59 100644
    </dependencyManagement>
  
 diff --git a/sql/core/pom.xml b/sql/core/pom.xml
-index c46ab7b8fce..13357e8c7a6 100644
+index 9577de81c20..a37f4a1f89f 100644
 --- a/sql/core/pom.xml
 +++ b/sql/core/pom.xml
 @@ -77,6 +77,10 @@
@@ -203,7 +203,7 @@ index 0efe0877e9b..423d3b3d76d 100644
  -- SELECT_HAVING
  -- 
https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
-index 8331a3c10fc..b4e22732a91 100644
+index 9815cb816c9..95b5f9992b0 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
 @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
@@ -226,10 +226,10 @@ index 8331a3c10fc..b4e22732a91 100644
  
    test("A cached table preserves the partitioning and ordering of its cached 
SparkPlan") {
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
-index 631fcd8c0d8..6df0e1b4176 100644
+index 5a8681aed97..da9d25e2eb4 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
-@@ -27,7 +27,7 @@ import org.apache.spark.{SparkException, SparkThrowable}
+@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand
  import org.apache.spark.sql.execution.WholeStageCodegenExec
  import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
  import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec, SortAggregateExec}
@@ -238,7 +238,7 @@ index 631fcd8c0d8..6df0e1b4176 100644
  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.internal.SQLConf
-@@ -792,7 +792,7 @@ class DataFrameAggregateSuite extends QueryTest
+@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest
        assert(objHashAggPlans.nonEmpty)
  
        val exchangePlans = collect(aggPlan) {
@@ -263,7 +263,7 @@ index 56e9520fdab..917932336df 100644
            spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
  
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
-index 002719f0689..784d24afe2d 100644
+index 7ee18df3756..64f01a68048 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
 @@ -40,11 +40,12 @@ import 
org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
@@ -280,7 +280,7 @@ index 002719f0689..784d24afe2d 100644
  import org.apache.spark.sql.expressions.{Aggregator, Window}
  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.internal.SQLConf
-@@ -2020,7 +2021,7 @@ class DataFrameSuite extends QueryTest
+@@ -2006,7 +2007,7 @@ class DataFrameSuite extends QueryTest
            fail("Should not have back to back Aggregates")
          }
          atFirstAgg = true
@@ -289,7 +289,7 @@ index 002719f0689..784d24afe2d 100644
        case _ =>
      }
    }
-@@ -2344,7 +2345,7 @@ class DataFrameSuite extends QueryTest
+@@ -2330,7 +2331,7 @@ class DataFrameSuite extends QueryTest
        checkAnswer(join, df)
        assert(
          collect(join.queryExecution.executedPlan) {
@@ -298,7 +298,7 @@ index 002719f0689..784d24afe2d 100644
        assert(
          collect(join.queryExecution.executedPlan) { case e: 
ReusedExchangeExec => true }.size === 1)
        val broadcasted = broadcast(join)
-@@ -2352,10 +2353,12 @@ class DataFrameSuite extends QueryTest
+@@ -2338,10 +2339,12 @@ class DataFrameSuite extends QueryTest
        checkAnswer(join2, df)
        assert(
          collect(join2.queryExecution.executedPlan) {
@@ -313,7 +313,7 @@ index 002719f0689..784d24afe2d 100644
        assert(
          collect(join2.queryExecution.executedPlan) { case e: 
ReusedExchangeExec => true }.size == 4)
      }
-@@ -2915,7 +2918,7 @@ class DataFrameSuite extends QueryTest
+@@ -2901,7 +2904,7 @@ class DataFrameSuite extends QueryTest
  
      // Assert that no extra shuffle introduced by cogroup.
      val exchanges = collect(df3.queryExecution.executedPlan) {
@@ -322,7 +322,7 @@ index 002719f0689..784d24afe2d 100644
      }
      assert(exchanges.size == 2)
    }
-@@ -3364,7 +3367,8 @@ class DataFrameSuite extends QueryTest
+@@ -3350,7 +3353,8 @@ class DataFrameSuite extends QueryTest
      assert(df2.isLocal)
    }
  
@@ -333,7 +333,7 @@ index 002719f0689..784d24afe2d 100644
        sql(
          """
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
-index c2fe31520ac..0f54b233d14 100644
+index f32b32ffc5a..447d7c6416e 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
 @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, 
LeftSemi}
@@ -715,10 +715,10 @@ index 7af826583bd..3c3def1eb67 100644
      assert(shuffleMergeJoins.size == 1)
    }
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-index 9dcf7ec2904..d8b014a4eb8 100644
+index 4d256154c85..43f0bebb00c 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-@@ -30,7 +30,8 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+@@ -31,7 +31,8 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
  import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, 
SortOrder}
  import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
  import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, 
JoinHint, NO_BROADCAST_AND_REPLICATION}
@@ -728,7 +728,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
  import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
  import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, 
ShuffleExchangeLike}
  import org.apache.spark.sql.execution.joins._
-@@ -801,7 +802,8 @@ class JoinSuite extends QueryTest with SharedSparkSession 
with AdaptiveSparkPlan
+@@ -802,7 +803,8 @@ class JoinSuite extends QueryTest with SharedSparkSession 
with AdaptiveSparkPlan
      }
    }
  
@@ -738,7 +738,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
        SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
        SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
-@@ -927,10 +929,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -928,10 +930,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
        val physical = df.queryExecution.sparkPlan
        val physicalJoins = physical.collect {
          case j: SortMergeJoinExec => j
@@ -751,7 +751,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        }
        // This only applies to the above tested queries, in which a child 
SortMergeJoin always
        // contains the SortOrder required by its parent SortMergeJoin. Thus, 
SortExec should never
-@@ -1176,9 +1180,11 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1177,9 +1181,11 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
        val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
          .groupBy($"k1").count()
          .queryExecution.executedPlan
@@ -765,7 +765,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
      })
    }
  
-@@ -1195,10 +1201,11 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1196,10 +1202,11 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
          .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
          .queryExecution
          .executedPlan
@@ -779,7 +779,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
      })
  
      // Test shuffled hash join
-@@ -1208,10 +1215,13 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1209,10 +1216,13 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
          .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
          .queryExecution
          .executedPlan
@@ -796,7 +796,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
      })
    }
  
-@@ -1302,12 +1312,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1303,12 +1313,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
      inputDFs.foreach { case (df1, df2, joinExprs) =>
        val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
        assert(collect(smjDF.queryExecution.executedPlan) {
@@ -811,7 +811,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        // Same result between shuffled hash join and sort merge join
        checkAnswer(shjDF, smjResult)
      }
-@@ -1366,12 +1376,14 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1367,12 +1377,14 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, 
"leftouter")
            assert(collect(smjDF.queryExecution.executedPlan) {
              case _: SortMergeJoinExec => true
@@ -826,7 +826,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
            }.size === 1)
            // Same result between shuffled hash join and sort merge join
            checkAnswer(shjDF, smjResult)
-@@ -1382,12 +1394,14 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1383,12 +1395,14 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, 
"rightouter")
            assert(collect(smjDF.queryExecution.executedPlan) {
              case _: SortMergeJoinExec => true
@@ -841,7 +841,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
            }.size === 1)
            // Same result between shuffled hash join and sort merge join
            checkAnswer(shjDF, smjResult)
-@@ -1431,13 +1445,19 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1432,13 +1446,19 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
          assert(shjCodegenDF.queryExecution.executedPlan.collect {
            case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
            case WholeStageCodegenExec(ProjectExec(_, _ : 
ShuffledHashJoinExec)) => true
@@ -862,7 +862,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
            checkAnswer(shjNonCodegenDF, Seq.empty)
          }
        }
-@@ -1485,7 +1505,8 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1486,7 +1506,8 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            val plan = sql(getAggQuery(selectExpr, 
joinType)).queryExecution.executedPlan
            assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
            // Have shuffle before aggregation
@@ -872,7 +872,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        }
  
        def getJoinQuery(selectExpr: String, joinType: String): String = {
-@@ -1514,9 +1535,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1515,9 +1536,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            }
            val plan = sql(getJoinQuery(selectExpr, 
joinType)).queryExecution.executedPlan
            assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
@@ -887,7 +887,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        }
  
        // Test output ordering is not preserved
-@@ -1525,9 +1549,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1526,9 +1550,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
            val plan = sql(getJoinQuery(selectExpr, 
joinType)).queryExecution.executedPlan
            assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
@@ -902,7 +902,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        }
  
        // Test singe partition
-@@ -1537,7 +1564,8 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1538,7 +1565,8 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
             |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
             |""".stripMargin)
        val plan = fullJoinDF.queryExecution.executedPlan
@@ -912,7 +912,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
        checkAnswer(fullJoinDF, Row(100))
      }
    }
-@@ -1582,6 +1610,9 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1583,6 +1611,9 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
            Seq(semiJoinDF, antiJoinDF).foreach { df =>
              assert(collect(df.queryExecution.executedPlan) {
                case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == 
ignoreDuplicatedKey => true
@@ -922,7 +922,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
              }.size == 1)
            }
        }
-@@ -1626,14 +1657,20 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1627,14 +1658,20 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
  
    test("SPARK-43113: Full outer join with duplicate stream-side references in 
condition (SMJ)") {
      def check(plan: SparkPlan): Unit = {
@@ -946,7 +946,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
      dupStreamSideColTest("SHUFFLE_HASH", check)
    }
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
-index b5b34922694..a72403780c4 100644
+index c26757c9cff..d55775f09d7 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
 @@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
@@ -959,10 +959,10 @@ index b5b34922694..a72403780c4 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index cfeccbdf648..803d8734cc4 100644
+index 793a0da6a86..6ccb9d62582 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-@@ -1510,7 +1510,8 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
+@@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
      checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
    }
  
@@ -1004,7 +1004,7 @@ index 8b4ac474f87..3f79f20822f 100644
          extensions.injectColumnar(session =>
            MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) 
}
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-index fbc256b3396..0821999c7c2 100644
+index 260c992f1ae..b9d8e22337c 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 @@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
@@ -1043,7 +1043,7 @@ index fbc256b3396..0821999c7c2 100644
        assert(exchanges.size === 1)
      }
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
-index 52d0151ee46..2b6d493cf38 100644
+index d269290e616..13726a31e07 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
 @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._
@@ -1131,7 +1131,7 @@ index cfc8b2cc845..c6fcfd7bd08 100644
          } finally {
            spark.listenerManager.unregister(listener)
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
-index 6b07c77aefb..8277661560e 100644
+index 71e030f535e..d5ae6cbf3d5 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row}
@@ -1449,23 +1449,23 @@ index 5a413c77754..a6f97dccb67 100644
            val df = spark.read.parquet(path).selectExpr(projection: _*)
  
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-index 68bae34790a..0cc77ad09d7 100644
+index 2f8e401e743..a4f94417dcc 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-@@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._
- 
+@@ -27,9 +27,11 @@ import org.scalatest.time.SpanSugar._
  import org.apache.spark.SparkException
  import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, 
SparkListenerJobStart}
+ import org.apache.spark.shuffle.sort.SortShuffleManager
 -import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 +import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, 
SparkSession, Strategy}
  import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
  import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 +import org.apache.spark.sql.comet._
 +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
- import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, 
LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, 
ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, 
UnionExec}
+ import org.apache.spark.sql.execution._
  import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
- import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-@@ -112,6 +114,7 @@ class AdaptiveQueryExecSuite
+ import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, 
InMemoryTableScanLike}
+@@ -117,6 +119,7 @@ class AdaptiveQueryExecSuite
    private def findTopLevelBroadcastHashJoin(plan: SparkPlan): 
Seq[BroadcastHashJoinExec] = {
      collect(plan) {
        case j: BroadcastHashJoinExec => j
@@ -1473,7 +1473,7 @@ index 68bae34790a..0cc77ad09d7 100644
      }
    }
  
-@@ -124,30 +127,39 @@ class AdaptiveQueryExecSuite
+@@ -129,36 +132,46 @@ class AdaptiveQueryExecSuite
    private def findTopLevelSortMergeJoin(plan: SparkPlan): 
Seq[SortMergeJoinExec] = {
      collect(plan) {
        case j: SortMergeJoinExec => j
@@ -1513,7 +1513,14 @@ index 68bae34790a..0cc77ad09d7 100644
      }
    }
  
-@@ -191,6 +203,7 @@ class AdaptiveQueryExecSuite
+   private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
+     collect(plan) {
+       case l: CollectLimitExec => l
++      case l: CometCollectLimitExec => 
l.originalPlan.asInstanceOf[CollectLimitExec]
+     }
+   }
+ 
+@@ -202,6 +215,7 @@ class AdaptiveQueryExecSuite
        val parts = rdd.partitions
        assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
      }
@@ -1521,7 +1528,7 @@ index 68bae34790a..0cc77ad09d7 100644
      assert(numShuffles === (numLocalReads.length + 
numShufflesWithoutLocalRead))
    }
  
-@@ -199,7 +212,7 @@ class AdaptiveQueryExecSuite
+@@ -210,7 +224,7 @@ class AdaptiveQueryExecSuite
      val plan = df.queryExecution.executedPlan
      assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
      val shuffle = 
plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
@@ -1530,7 +1537,7 @@ index 68bae34790a..0cc77ad09d7 100644
      }
      assert(shuffle.size == 1)
      assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
-@@ -215,7 +228,8 @@ class AdaptiveQueryExecSuite
+@@ -226,7 +240,8 @@ class AdaptiveQueryExecSuite
        assert(smj.size == 1)
        val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
        assert(bhj.size == 1)
@@ -1540,7 +1547,7 @@ index 68bae34790a..0cc77ad09d7 100644
      }
    }
  
-@@ -242,7 +256,8 @@ class AdaptiveQueryExecSuite
+@@ -253,7 +268,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1550,7 +1557,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
-@@ -274,7 +289,8 @@ class AdaptiveQueryExecSuite
+@@ -285,7 +301,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1560,7 +1567,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
-@@ -288,7 +304,8 @@ class AdaptiveQueryExecSuite
+@@ -299,7 +316,8 @@ class AdaptiveQueryExecSuite
        val localReads = collect(adaptivePlan) {
          case read: AQEShuffleReadExec if read.isLocalRead => read
        }
@@ -1570,7 +1577,7 @@ index 68bae34790a..0cc77ad09d7 100644
        val localShuffleRDD0 = 
localReads(0).execute().asInstanceOf[ShuffledRowRDD]
        val localShuffleRDD1 = 
localReads(1).execute().asInstanceOf[ShuffledRowRDD]
        // the final parallelism is math.max(1, numReduces / numMappers): 
math.max(1, 5/2) = 2
-@@ -313,7 +330,9 @@ class AdaptiveQueryExecSuite
+@@ -324,7 +342,9 @@ class AdaptiveQueryExecSuite
            .groupBy($"a").count()
          checkAnswer(testDf, Seq())
          val plan = testDf.queryExecution.executedPlan
@@ -1581,7 +1588,7 @@ index 68bae34790a..0cc77ad09d7 100644
          val coalescedReads = collect(plan) {
            case r: AQEShuffleReadExec => r
          }
-@@ -327,7 +346,9 @@ class AdaptiveQueryExecSuite
+@@ -338,7 +358,9 @@ class AdaptiveQueryExecSuite
            .groupBy($"a").count()
          checkAnswer(testDf, Seq())
          val plan = testDf.queryExecution.executedPlan
@@ -1592,7 +1599,7 @@ index 68bae34790a..0cc77ad09d7 100644
          val coalescedReads = collect(plan) {
            case r: AQEShuffleReadExec => r
          }
-@@ -337,7 +358,7 @@ class AdaptiveQueryExecSuite
+@@ -348,7 +370,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1601,7 +1608,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -352,7 +373,7 @@ class AdaptiveQueryExecSuite
+@@ -363,7 +385,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1610,7 +1617,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -368,7 +389,7 @@ class AdaptiveQueryExecSuite
+@@ -379,7 +401,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1619,7 +1626,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -413,7 +434,7 @@ class AdaptiveQueryExecSuite
+@@ -424,7 +446,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1628,7 +1635,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -458,7 +479,7 @@ class AdaptiveQueryExecSuite
+@@ -469,7 +491,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1637,7 +1644,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
-@@ -504,7 +525,7 @@ class AdaptiveQueryExecSuite
+@@ -515,7 +537,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1646,7 +1653,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -523,7 +544,7 @@ class AdaptiveQueryExecSuite
+@@ -534,7 +556,7 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1655,7 +1662,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -554,7 +575,9 @@ class AdaptiveQueryExecSuite
+@@ -565,7 +587,9 @@ class AdaptiveQueryExecSuite
        assert(smj.size == 1)
        val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
        assert(bhj.size == 1)
@@ -1666,7 +1673,7 @@ index 68bae34790a..0cc77ad09d7 100644
        // Even with local shuffle read, the query stage reuse can also work.
        val ex = findReusedExchange(adaptivePlan)
        assert(ex.nonEmpty)
-@@ -575,7 +598,9 @@ class AdaptiveQueryExecSuite
+@@ -586,7 +610,9 @@ class AdaptiveQueryExecSuite
        assert(smj.size == 1)
        val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
        assert(bhj.size == 1)
@@ -1677,7 +1684,7 @@ index 68bae34790a..0cc77ad09d7 100644
        // Even with local shuffle read, the query stage reuse can also work.
        val ex = findReusedExchange(adaptivePlan)
        assert(ex.isEmpty)
-@@ -584,7 +609,8 @@ class AdaptiveQueryExecSuite
+@@ -595,7 +621,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1687,7 +1694,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
-@@ -679,7 +705,8 @@ class AdaptiveQueryExecSuite
+@@ -690,7 +717,8 @@ class AdaptiveQueryExecSuite
        val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
        assert(bhj.size == 1)
        // There is still a SMJ, and its two shuffles can't apply local read.
@@ -1697,7 +1704,7 @@ index 68bae34790a..0cc77ad09d7 100644
      }
    }
  
-@@ -801,7 +828,8 @@ class AdaptiveQueryExecSuite
+@@ -812,7 +840,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1707,7 +1714,7 @@ index 68bae34790a..0cc77ad09d7 100644
      Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
        def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == 
"SHUFFLE_MERGE") {
          findTopLevelSortMergeJoin(plan)
-@@ -1019,7 +1047,8 @@ class AdaptiveQueryExecSuite
+@@ -1030,7 +1059,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1717,7 +1724,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
          "SELECT key FROM testData GROUP BY key")
-@@ -1614,7 +1643,7 @@ class AdaptiveQueryExecSuite
+@@ -1625,7 +1655,7 @@ class AdaptiveQueryExecSuite
          val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
            "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
          assert(collect(adaptivePlan) {
@@ -1726,7 +1733,7 @@ index 68bae34790a..0cc77ad09d7 100644
          }.length == 1)
        }
      }
-@@ -1694,7 +1723,8 @@ class AdaptiveQueryExecSuite
+@@ -1705,7 +1735,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1736,7 +1743,7 @@ index 68bae34790a..0cc77ad09d7 100644
      def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
        find(plan) {
          case s: ShuffleExchangeLike =>
-@@ -1879,6 +1909,9 @@ class AdaptiveQueryExecSuite
+@@ -1890,6 +1921,9 @@ class AdaptiveQueryExecSuite
      def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): 
Unit = {
        assert(collect(ds.queryExecution.executedPlan) {
          case s: ShuffleExchangeExec if s.shuffleOrigin == origin && 
s.numPartitions == 2 => s
@@ -1746,7 +1753,7 @@ index 68bae34790a..0cc77ad09d7 100644
        }.size == 1)
        ds.collect()
        val plan = ds.queryExecution.executedPlan
-@@ -1887,6 +1920,9 @@ class AdaptiveQueryExecSuite
+@@ -1898,6 +1932,9 @@ class AdaptiveQueryExecSuite
        }.isEmpty)
        assert(collect(plan) {
          case s: ShuffleExchangeExec if s.shuffleOrigin == origin && 
s.numPartitions == 2 => s
@@ -1756,7 +1763,7 @@ index 68bae34790a..0cc77ad09d7 100644
        }.size == 1)
        checkAnswer(ds, testData)
      }
-@@ -2043,7 +2079,8 @@ class AdaptiveQueryExecSuite
+@@ -2054,7 +2091,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1766,7 +1773,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withTempView("t1", "t2") {
        def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
          Seq("100", "100000").foreach { size =>
-@@ -2129,7 +2166,8 @@ class AdaptiveQueryExecSuite
+@@ -2140,7 +2178,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1776,7 +1783,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withTempView("v") {
        withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-@@ -2228,7 +2266,7 @@ class AdaptiveQueryExecSuite
+@@ -2239,7 +2278,7 @@ class AdaptiveQueryExecSuite
                runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM 
skewData1 " +
                  s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
              val shuffles1 = collect(adaptive1) {
@@ -1785,7 +1792,7 @@ index 68bae34790a..0cc77ad09d7 100644
              }
              assert(shuffles1.size == 3)
              // shuffles1.head is the top-level shuffle under the Aggregate 
operator
-@@ -2241,7 +2279,7 @@ class AdaptiveQueryExecSuite
+@@ -2252,7 +2291,7 @@ class AdaptiveQueryExecSuite
                runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM 
skewData1 " +
                  s"JOIN skewData2 ON key1 = key2")
              val shuffles2 = collect(adaptive2) {
@@ -1794,7 +1801,7 @@ index 68bae34790a..0cc77ad09d7 100644
              }
              if (hasRequiredDistribution) {
                assert(shuffles2.size == 3)
-@@ -2275,7 +2313,8 @@ class AdaptiveQueryExecSuite
+@@ -2286,7 +2325,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1804,7 +1811,17 @@ index 68bae34790a..0cc77ad09d7 100644
      CostEvaluator.instantiate(
        classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, 
spark.sparkContext.getConf)
      intercept[IllegalArgumentException] {
-@@ -2419,6 +2458,7 @@ class AdaptiveQueryExecSuite
+@@ -2417,7 +2457,8 @@ class AdaptiveQueryExecSuite
+   }
+ 
+   test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related 
metrics " +
+-    "resulting in potentially inaccurate data") {
++    "resulting in potentially inaccurate data",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/1501")) {
+     withTable("t3") {
+       withSQLConf(
+         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+@@ -2452,6 +2493,7 @@ class AdaptiveQueryExecSuite
            val (_, adaptive) = runAdaptiveAndVerifyResult(query)
            assert(adaptive.collect {
              case sort: SortExec => sort
@@ -1812,7 +1829,7 @@ index 68bae34790a..0cc77ad09d7 100644
            }.size == 1)
            val read = collect(adaptive) {
              case read: AQEShuffleReadExec => read
-@@ -2436,7 +2476,8 @@ class AdaptiveQueryExecSuite
+@@ -2469,7 +2511,8 @@ class AdaptiveQueryExecSuite
      }
    }
  
@@ -1822,7 +1839,7 @@ index 68bae34790a..0cc77ad09d7 100644
      withTempView("v") {
        withSQLConf(
          SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key 
-> "true",
-@@ -2548,7 +2589,7 @@ class AdaptiveQueryExecSuite
+@@ -2581,7 +2624,7 @@ class AdaptiveQueryExecSuite
            runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN 
skewData2 ON key1 = key2 " +
              "JOIN skewData3 ON value2 = value3")
          val shuffles1 = collect(adaptive1) {
@@ -1831,7 +1848,7 @@ index 68bae34790a..0cc77ad09d7 100644
          }
          assert(shuffles1.size == 4)
          val smj1 = findTopLevelSortMergeJoin(adaptive1)
-@@ -2559,7 +2600,7 @@ class AdaptiveQueryExecSuite
+@@ -2592,7 +2635,7 @@ class AdaptiveQueryExecSuite
            runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN 
skewData2 ON key1 = key2 " +
              "JOIN skewData3 ON value1 = value3")
          val shuffles2 = collect(adaptive2) {
@@ -1840,14 +1857,25 @@ index 68bae34790a..0cc77ad09d7 100644
          }
          assert(shuffles2.size == 4)
          val smj2 = findTopLevelSortMergeJoin(adaptive2)
-@@ -2756,6 +2797,7 @@ class AdaptiveQueryExecSuite
+@@ -2850,6 +2893,7 @@ class AdaptiveQueryExecSuite
          }.size == (if (firstAccess) 1 else 0))
          assert(collect(initialExecutedPlan) {
            case s: SortExec => s
 +          case s: CometSortExec => s
          }.size == (if (firstAccess) 2 else 0))
          assert(collect(initialExecutedPlan) {
-           case i: InMemoryTableScanExec => i
+           case i: InMemoryTableScanLike => i
+@@ -2980,7 +3024,9 @@ class AdaptiveQueryExecSuite
+ 
+       val plan = 
df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+       assert(plan.inputPlan.isInstanceOf[TakeOrderedAndProjectExec])
+-      assert(plan.finalPhysicalPlan.isInstanceOf[WindowExec])
++      assert(
++        plan.finalPhysicalPlan.isInstanceOf[WindowExec] ||
++        plan.finalPhysicalPlan.find(_.isInstanceOf[CometWindowExec]).nonEmpty)
+       plan.inputPlan.output.zip(plan.finalPhysicalPlan.output).foreach { case 
(o1, o2) =>
+         assert(o1.semanticEquals(o2), "Different output column order after 
AQE optimization")
+       }
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
 index 05872d41131..a2c328b9742 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
@@ -2094,10 +2122,10 @@ index 4f8a9e39716..fb55ac7a955 100644
        checkAnswer(
          // "fruit" column in this file is encoded using 
DELTA_LENGTH_BYTE_ARRAY.
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
-index 828ec39c7d7..369b3848192 100644
+index f6472ba3d9d..dc13e00c853 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
-@@ -1041,7 +1041,8 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
+@@ -1067,7 +1067,8 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
          checkAnswer(readParquet(schema, path), df)
        }
  
@@ -2107,7 +2135,7 @@ index 828ec39c7d7..369b3848192 100644
          val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
          checkAnswer(readParquet(schema1, path), df)
          val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
-@@ -1063,7 +1064,8 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
+@@ -1089,7 +1090,8 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
        val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c, 
CAST('1.2' AS BINARY) d")
        df.write.parquet(path.toString)
  
@@ -2117,7 +2145,7 @@ index 828ec39c7d7..369b3848192 100644
          checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
          checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
          checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 
123456.0"))
-@@ -1122,7 +1124,7 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
+@@ -1148,7 +1150,7 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS
              .where(s"a < ${Long.MaxValue}")
              .collect()
          }
@@ -2241,7 +2269,7 @@ index b8f3ea3c6f3..bbd44221288 100644
        val workDirPath = workDir.getAbsolutePath
        val input = spark.range(5).toDF("id")
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
-index 6347757e178..6d0fa493308 100644
+index 5cdbdc27b32..307fba16578 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 @@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
@@ -2533,7 +2561,7 @@ index d675503a8ba..659fa686fb7 100644
      }
  
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
-index 75f440caefc..36b1146bc3a 100644
+index 1954cce7fdc..73d1464780e 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
 @@ -34,6 +34,7 @@ import org.apache.spark.paths.SparkPath
@@ -2544,7 +2572,7 @@ index 75f440caefc..36b1146bc3a 100644
  import org.apache.spark.sql.execution.DataSourceScanExec
  import org.apache.spark.sql.execution.datasources._
  import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, 
DataSourceV2Relation, FileScan, FileTable}
-@@ -748,6 +749,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
+@@ -761,6 +762,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
        val fileScan = df.queryExecution.executedPlan.collect {
          case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
            batch.scan.asInstanceOf[FileScan]
@@ -2706,7 +2734,7 @@ index b4c4ec7acbf..20579284856 100644
  
          val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
-index 3e1bc57dfa2..4a8d75ff512 100644
+index aad91601758..201083bd621 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
@@ -2778,7 +2806,7 @@ index abe606ad9c1..2d930b64cca 100644
      val tblTargetName = "tbl_target"
      val tblSourceQualified = s"default.$tblSourceName"
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
-index dd55fcfe42c..0d66bcccbdc 100644
+index e937173a590..c2e00c53cc3 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
 @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -2832,7 +2860,7 @@ index dd55fcfe42c..0d66bcccbdc 100644
    protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): 
Unit = {
      SparkSession.setActiveSession(spark)
      super.withSQLConf(pairs: _*)(f)
-@@ -434,6 +462,8 @@ private[sql] trait SQLTestUtilsBase
+@@ -435,6 +463,8 @@ private[sql] trait SQLTestUtilsBase
      val schema = df.schema
      val withoutFilters = df.queryExecution.executedPlan.transform {
        case FilterExec(_, child) => child
@@ -2927,7 +2955,7 @@ index dc8b184fcee..dd69a989d40 100644
        spark.sql(
          """
 diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-index 9284b35fb3e..37f91610500 100644
+index 1d646f40b3e..7f2cdb8f061 100644
 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
 +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
 @@ -53,25 +53,55 @@ object TestHive
diff --git a/docs/source/contributor-guide/spark-sql-tests.md 
b/docs/source/contributor-guide/spark-sql-tests.md
index 0cdef50ab..cb88d2f43 100644
--- a/docs/source/contributor-guide/spark-sql-tests.md
+++ b/docs/source/contributor-guide/spark-sql-tests.md
@@ -72,11 +72,11 @@ of Apache Spark to enable Comet when running tests. This is 
a highly manual proc
 vary depending on the changes in the new version of Spark, but here is a 
general guide to the process.
 
 We typically start by applying a patch from a previous version of Spark. For 
example, when enabling the tests 
-for Spark version 3.5.1 we may start by applying the existing diff for 3.4.3 
first.
+for Spark version 3.5.4 we may start by applying the existing diff for 3.4.3 
first.
 
 ```shell
 cd git/apache/spark
-git checkout v3.5.1
+git checkout v3.5.4
 git apply --reject --whitespace=fix ../datafusion-comet/dev/diffs/3.4.3.diff
 ```
 
@@ -118,7 +118,7 @@ wiggle --replace 
./sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.sc
 ## Generating The Diff File
 
 ```shell
-git diff v3.5.1 > ../datafusion-comet/dev/diffs/3.5.1.diff
+git diff v3.5.4 > ../datafusion-comet/dev/diffs/3.5.4.diff
 ```
 
 ## Running Tests in CI 
diff --git 
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs 
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 1f9400b35..320938a5f 100644
--- 
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ 
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -17,7 +17,9 @@
 
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow_array::builder::GenericStringBuilder;
-use arrow_array::Array;
+use arrow_array::cast::as_dictionary_array;
+use arrow_array::types::Int32Type;
+use arrow_array::{make_array, Array, DictionaryArray};
 use arrow_schema::DataType;
 use datafusion::physical_plan::ColumnarValue;
 use datafusion_common::{cast::as_generic_string_array, DataFusionError, 
ScalarValue};
@@ -45,14 +47,26 @@ fn spark_read_side_padding2(
                 DataType::LargeUtf8 => {
                     spark_read_side_padding_internal::<i64>(array, *length, 
truncate)
                 }
-                // TODO: handle Dictionary types
+                // Dictionary support required for SPARK-48498
+                DataType::Dictionary(_, value_type) => {
+                    let dict = as_dictionary_array::<Int32Type>(array);
+                    let col = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_read_side_padding_internal::<i32>(dict.values(), 
*length, truncate)?
+                    } else {
+                        spark_read_side_padding_internal::<i64>(dict.values(), 
*length, truncate)?
+                    };
+                    // col consists of an array, so arg of to_array() is not 
used. Can be anything
+                    let values = col.to_array(0)?;
+                    let result = DictionaryArray::try_new(dict.keys().clone(), 
values)?;
+                    Ok(ColumnarValue::Array(make_array(result.into())))
+                }
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function 
read_side_padding",
+                    "Unsupported data type {other:?} for function 
rpad/read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function read_side_padding",
+            "Unsupported arguments {other:?} for function 
rpad/read_side_padding",
         ))),
     }
 }
diff --git a/pom.xml b/pom.xml
index f55e44316..e236c1dd7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -50,8 +50,8 @@ under the License.
     <scala.version>2.12.17</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
     <scala.plugin.version>4.7.2</scala.plugin.version>
-    <scalatest.version>3.2.9</scalatest.version>
-    <scalatest-maven-plugin.version>2.0.2</scalatest-maven-plugin.version>
+    <scalatest.version>3.2.16</scalatest.version>
+    <scalatest-maven-plugin.version>2.2.0</scalatest-maven-plugin.version>
     <spark.version>3.4.3</spark.version>
     <spark.version.short>3.4</spark.version.short>
     <spark.maven.scope>provided</spark.maven.scope>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to