erenavsarogullari commented on code in PR #45234:
URL: https://github.com/apache/spark/pull/45234#discussion_r1525339983


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala:
##########
@@ -897,6 +897,138 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-47148: AQE should avoid to materialize ShuffleQueryStage on the 
cancellation") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTable("bucketed_table1", "bucketed_table2", "bucketed_table3") {
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", 
"j", "k")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table1")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table2")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table3")
+
+        val warehouseFilePath = new 
URI(spark.sessionState.conf.warehousePath).getPath
+        val tableDir = new File(warehouseFilePath, "bucketed_table2")
+        Utils.deleteRecursively(tableDir)
+        df.write.parquet(tableDir.getAbsolutePath)
+
+        val aggDF1 = spark.table("bucketed_table1").groupBy("i", "j", 
"k").count().repartition(5)
+        val aggDF2 = spark.table("bucketed_table2").groupBy("i", 
"j").count().repartition(5)
+        val aggDF3 = 
spark.table("bucketed_table3").groupBy("i").count().repartition(5)
+        val joinedDF = aggDF1.join(aggDF2, Seq("i", "j")).join(aggDF3, 
Seq("i"))
+
+        val error = intercept[Exception] {
+          joinedDF.collect()
+        }
+        assert(error.getMessage() contains "Invalid bucket file")
+        assert(error.getSuppressed.size === 0)
+
+        val adaptivePlan = 
joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+
+        // There should not be BroadcastQueryStageExec
+        val broadcastQueryStageExecs = collect(adaptivePlan) {
+          case r: BroadcastQueryStageExec => r
+        }
+        assert(broadcastQueryStageExecs.isEmpty, s"$adaptivePlan")
+
+        // All QueryStages should be based on ShuffleQueryStageExec
+        val shuffleQueryStageExecs = collect(adaptivePlan) {
+          case r: ShuffleQueryStageExec => r
+        }
+        assert(shuffleQueryStageExecs.length == 3, s"$adaptivePlan")
+        // First ShuffleQueryStage is materialized so it needs to be canceled.
+        assert(shuffleQueryStageExecs(0).isMaterializationStarted(),
+          "Materialization should be started.")
+        // Second ShuffleQueryStage materialization is failed so
+        // it is excluded from the cancellation due to earlyFailedStage.
+        assert(shuffleQueryStageExecs(1).isMaterializationStarted(),
+          "Materialization should be started but it is failed.")
+        // Last ShuffleQueryStage is not materialized yet so it does not 
require
+        // to be canceled and it is just skipped from the cancellation.
+        assert(!shuffleQueryStageExecs(2).isMaterializationStarted(),
+          "Materialization should not be started.")
+      }
+    }
+  }
+
+  test("SPARK-47148: Check if BroadcastQueryStage materialization is started") 
{
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
+      withTable("bucketed_table1", "bucketed_table2", "bucketed_table3") {
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", 
"j", "k")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table1")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table2")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table3")
+
+        val warehouseFilePath = new 
URI(spark.sessionState.conf.warehousePath).getPath
+        val tableDir = new File(warehouseFilePath, "bucketed_table2")
+        Utils.deleteRecursively(tableDir)
+        df.write.parquet(tableDir.getAbsolutePath)
+
+        val aggDF1 = spark.table("bucketed_table1").groupBy("i", "j", 
"k").count()
+        val aggDF2 = spark.table("bucketed_table2").groupBy("i", "j", 
"k").count()
+        val aggDF3 = spark.table("bucketed_table3").groupBy("i", "j").count()
+        val joinedDF = aggDF1.join(aggDF2, Seq("i", "j", "k")).join(aggDF3, 
Seq("i", "j"))
+          .join(aggDF1, Seq("i"))
+
+        val error = intercept[Exception] {
+          joinedDF.collect()
+        }
+        assert(error.getMessage() contains "Invalid bucket file")
+        assert(error.getSuppressed.size === 0)
+
+        val adaptivePlan = 
joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+        // There should not be ShuffleQueryStageExec
+        val shuffleQueryStageExecs = collect(adaptivePlan) {
+          case r: ShuffleQueryStageExec => r
+        }
+        assert(shuffleQueryStageExecs.isEmpty, s"$adaptivePlan")
+
+        // All QueryStages should be based on BroadcastQueryStageExec
+        val broadcastQueryStageExecs = collect(adaptivePlan) {
+          case r: BroadcastQueryStageExec => r
+        }
+        assert(broadcastQueryStageExecs.length == 3, s"$adaptivePlan")
+        broadcastQueryStageExecs.foreach { bqse =>
+          assert(bqse.isMaterializationStarted(),
+            s"${bqse.name}' s materialization should be started before " +
+              "BroadcastQueryStage-1' s materialization is failed.")
+        }
+      }
+    }
+  }
+
+  test("SPARK-47148: Check AQE QueryStages names") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
+      withTable("bucketed_table1", "bucketed_table2") {
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", 
"j", "k")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table1")
+        df.write.format("parquet").bucketBy(8, 
"i").saveAsTable("bucketed_table2")
+
+        val df1 = spark.table("bucketed_table1").persist()
+        val df2 = spark.table("bucketed_table2").persist()
+        val joinedDF = df1.join(df2, Seq("i", "j", "k")).join(df1, Seq("i"))
+          .repartition(5).sort("i")
+        joinedDF.collect()
+
+        // Verify QueryStageExecs names
+        val adaptivePlanOfJoinedDF =
+          
joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+        val queryStageExecs = collect(adaptivePlanOfJoinedDF) {
+          case qse: QueryStageExec => qse
+        }
+        assert(queryStageExecs.size == 7, s"$adaptivePlanOfJoinedDF")
+        
assert(queryStageExecs.filter(_.name.contains("TableCacheQueryStageExec-")).size
 == 3)
+        
assert(queryStageExecs.filter(_.name.contains("BroadcastQueryStageExec-")).size 
== 2)
+        
assert(queryStageExecs.filter(_.name.contains("ShuffleQueryStageExec-")).size 
== 2)
+      }
+    }
+  }
+

Review Comment:
   We need a stage failure to trigger the cancellation, so I used the [existing test](https://github.com/apache/spark/pull/45234/files#diff-f89f2fe78b324c6bc7190bef84220181f3616efc156ea99b3f15d375a22d7f88R878) to simulate this case as well.
   I think we can have a single test case that checks both new properties, `QueryStageExec.materializationStarted` and `name`, as in the last commit.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to