This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b40920e331 [SPARK-41713][SQL] Make CTAS hold a nested execution for 
data writing
4b40920e331 is described below

commit 4b40920e33176fc8b18380703e4dcf4d16824094
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Wed Dec 28 17:11:59 2022 +0800

    [SPARK-41713][SQL] Make CTAS hold a nested execution for data writing
    
    ### What changes were proposed in this pull request?
    
    This PR makes CTAS use a nested query execution for data writing, instead of
    running the data writing command directly inside the CTAS command.
    
    As a result, the CTAS commands themselves no longer need to carry V1Write
    information (partition columns, required ordering, file-format checks), so it
    can be removed. `V1WriteCommand` now has only two implementations:
    `InsertIntoHadoopFsRelationCommand` and `InsertIntoHiveTable`.
    
    ### Why are the changes needed?
    
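    It makes the V1Writes code clearer. With this change, the CTAS command no longer
    carries the physical plan of the data write; the query is kept only as an inner
    child, as this EXPLAIN output shows: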
    
    ```sql
    EXPLAIN FORMATTED CREATE TABLE t2 USING PARQUET AS SELECT * FROM t;
    
    == Physical Plan ==
    Execute CreateDataSourceTableAsSelectCommand (1)
       +- CreateDataSourceTableAsSelectCommand (2)
             +- Project (5)
                +- SubqueryAlias (4)
                   +- LogicalRelation (3)
    
    (1) Execute CreateDataSourceTableAsSelectCommand
    Output: []
    
    (2) CreateDataSourceTableAsSelectCommand
    Arguments: `spark_catalog`.`default`.`t2`, ErrorIfExists, [c1, c2]
    
    (3) LogicalRelation
    Arguments: parquet, [c1#11, c2#12], `spark_catalog`.`default`.`t`, 
org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, false
    
    (4) SubqueryAlias
    Arguments: spark_catalog.default.t
    
    (5) Project
    Arguments: [c1#11, c2#12]
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Improved existing tests.
    
    Closes #39220 from ulysses-you/SPARK-41713.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/command/createDataSourceTables.scala | 40 +++------------
 .../sql/execution/datasources/DataSource.scala     | 48 ++++++------------
 .../spark/sql/execution/datasources/V1Writes.scala |  8 +--
 .../scala/org/apache/spark/sql/ExplainSuite.scala  |  7 ++-
 .../adaptive/AdaptiveQueryExecSuite.scala          | 58 ++++++++++++++--------
 .../datasources/V1WriteCommandSuite.scala          | 17 +++----
 .../sql/execution/metric/SQLMetricsSuite.scala     | 41 ++++++++++-----
 .../spark/sql/util/DataFrameCallbackSuite.scala    | 12 +++--
 .../execution/CreateHiveTableAsSelectCommand.scala | 46 +++++------------
 .../sql/hive/execution/HiveExplainSuite.scala      | 16 ++----
 .../spark/sql/hive/execution/SQLMetricsSuite.scala | 49 +++++++++++-------
 11 files changed, 159 insertions(+), 183 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 9bf9f43829e..bf14ef14cf4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command
 import java.net.URI
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan}
+import org.apache.spark.sql.execution.CommandExecutionMode
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -143,29 +141,11 @@ case class CreateDataSourceTableAsSelectCommand(
     mode: SaveMode,
     query: LogicalPlan,
     outputColumnNames: Seq[String])
-  extends V1WriteCommand {
-
-  override def fileFormatProvider: Boolean = {
-    table.provider.forall { provider =>
-      classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, 
conf))
-    }
-  }
-
-  override lazy val partitionColumns: Seq[Attribute] = {
-    val unresolvedPartitionColumns = 
table.partitionColumnNames.map(UnresolvedAttribute.quoted)
-    DataSource.resolvePartitionColumns(
-      unresolvedPartitionColumns,
-      outputColumns,
-      query,
-      SparkSession.active.sessionState.conf.resolver)
-  }
-
-  override def requiredOrdering: Seq[SortOrder] = {
-    val options = table.storage.properties
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, 
table.bucketSpec, options)
-  }
+  extends LeafRunnableCommand {
+  assert(query.resolved)
+  override def innerChildren: Seq[LogicalPlan] = query :: Nil
 
-  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+  override def run(sparkSession: SparkSession): Seq[Row] = {
     assert(table.tableType != CatalogTableType.VIEW)
     assert(table.provider.isDefined)
 
@@ -187,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand(
       }
 
       saveDataIntoTable(
-        sparkSession, table, table.storage.locationUri, child, 
SaveMode.Append, tableExists = true)
+        sparkSession, table, table.storage.locationUri, SaveMode.Append, 
tableExists = true)
     } else {
       table.storage.locationUri.foreach { p =>
         DataWritingCommand.assertEmptyRootPath(p, mode, 
sparkSession.sessionState.newHadoopConf)
@@ -200,7 +180,7 @@ case class CreateDataSourceTableAsSelectCommand(
         table.storage.locationUri
       }
       val result = saveDataIntoTable(
-        sparkSession, table, tableLocation, child, SaveMode.Overwrite, 
tableExists = false)
+        sparkSession, table, tableLocation, SaveMode.Overwrite, tableExists = 
false)
       val tableSchema = CharVarcharUtils.getRawSchema(result.schema, 
sessionState.conf)
       val newTable = table.copy(
         storage = table.storage.copy(locationUri = tableLocation),
@@ -232,7 +212,6 @@ case class CreateDataSourceTableAsSelectCommand(
       session: SparkSession,
       table: CatalogTable,
       tableLocation: Option[URI],
-      physicalPlan: SparkPlan,
       mode: SaveMode,
       tableExists: Boolean): BaseRelation = {
     // Create the relation based on the input logical plan: `query`.
@@ -246,14 +225,11 @@ case class CreateDataSourceTableAsSelectCommand(
       catalogTable = if (tableExists) Some(table) else None)
 
     try {
-      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, 
metrics)
+      dataSource.writeAndRead(mode, query, outputColumnNames)
     } catch {
       case ex: AnalysisException =>
         logError(s"Failed to write to table 
${table.identifier.unquotedString}", ex)
         throw ex
     }
   }
-
-  override protected def withNewChildInternal(newChild: LogicalPlan): 
LogicalPlan =
-    copy(query = newChild)
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 3d8eb9bc8a8..edbdd6bbc67 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
 import org.apache.spark.sql.connector.catalog.TableProvider
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
-import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -45,7 +44,6 @@ import 
org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, 
TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -97,8 +95,19 @@ case class DataSource(
 
   case class SourceInfo(name: String, schema: StructType, partitionColumns: 
Seq[String])
 
-  lazy val providingClass: Class[_] =
-    DataSource.providingClass(className, sparkSession.sessionState.conf)
+  lazy val providingClass: Class[_] = {
+    val cls = DataSource.lookupDataSource(className, 
sparkSession.sessionState.conf)
+    // `providingClass` is used for resolving data source relation for catalog 
tables.
+    // As now catalog for data source V2 is under development, here we fall 
back all the
+    // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog 
works.
+    // [[FileDataSourceV2]] will still be used if we call the load()/save() 
method in
+    // [[DataFrameReader]]/[[DataFrameWriter]], since they use method 
`lookupDataSource`
+    // instead of `providingClass`.
+    cls.newInstance() match {
+      case f: FileDataSourceV2 => f.fallbackFileFormat
+      case _ => cls
+    }
+  }
 
   private[sql] def providingInstance(): Any = 
providingClass.getConstructor().newInstance()
 
@@ -483,17 +492,11 @@ case class DataSource(
    * @param outputColumnNames The original output column names of the input 
query plan. The
    *                          optimizer may not preserve the output column's 
names' case, so we need
    *                          this parameter instead of `data.output`.
-   * @param physicalPlan The physical plan of the input query plan. We should 
run the writing
-   *                     command with this physical plan instead of creating a 
new physical plan,
-   *                     so that the metrics can be correctly linked to the 
given physical plan and
-   *                     shown in the web UI.
    */
   def writeAndRead(
       mode: SaveMode,
       data: LogicalPlan,
-      outputColumnNames: Seq[String],
-      physicalPlan: SparkPlan,
-      metrics: Map[String, SQLMetric]): BaseRelation = {
+      outputColumnNames: Seq[String]): BaseRelation = {
     val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, 
outputColumnNames)
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
@@ -503,13 +506,8 @@ case class DataSource(
       case format: FileFormat =>
         disallowWritingIntervals(outputColumns.map(_.dataType), 
forbidAnsiIntervals = false)
         val cmd = planForWritingFileFormat(format, mode, data)
-        val resolvedPartCols =
-          DataSource.resolvePartitionColumns(cmd.partitionColumns, 
outputColumns, data, equality)
-        val resolved = cmd.copy(
-          partitionColumns = resolvedPartCols,
-          outputColumnNames = outputColumnNames)
-        resolved.run(sparkSession, physicalPlan)
-        DataWritingCommand.propogateMetrics(sparkSession.sparkContext, 
resolved, metrics)
+        val qe = sparkSession.sessionState.executePlan(cmd)
+        qe.assertCommandExecuted()
         // Replace the schema with that of the DataFrame we just wrote out to 
avoid re-inferring
         copy(userSpecifiedSchema = 
Some(outputColumns.toStructType.asNullable)).resolveRelation()
       case _ => throw new IllegalStateException(
@@ -832,18 +830,4 @@ object DataSource extends Logging {
       }
     }
   }
-
-  def providingClass(className: String, conf: SQLConf): Class[_] = {
-    val cls = DataSource.lookupDataSource(className, conf)
-    // `providingClass` is used for resolving data source relation for catalog 
tables.
-    // As now catalog for data source V2 is under development, here we fall 
back all the
-    // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog 
works.
-    // [[FileDataSourceV2]] will still be used if we call the load()/save() 
method in
-    // [[DataFrameReader]]/[[DataFrameWriter]], since they use method 
`lookupDataSource`
-    // instead of `providingClass`.
-    cls.newInstance() match {
-      case f: FileDataSourceV2 => f.fallbackFileFormat
-      case _ => cls
-    }
-  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index e9f6e3df785..3ed04e5bd6d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -31,11 +31,6 @@ import org.apache.spark.sql.types.StringType
 import org.apache.spark.unsafe.types.UTF8String
 
 trait V1WriteCommand extends DataWritingCommand {
-  /**
-   * Return if the provider is [[FileFormat]]
-   */
-  def fileFormatProvider: Boolean = true
-
   /**
    * Specify the partition columns of the V1 write command.
    */
@@ -58,8 +53,7 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (conf.plannedWriteEnabled) {
       plan.transformUp {
-        case write: V1WriteCommand if write.fileFormatProvider &&
-          !write.child.isInstanceOf[WriteFiles] =>
+        case write: V1WriteCommand if !write.child.isInstanceOf[WriteFiles] =>
           val newQuery = prepareQuery(write, write.query)
           val attrMap = AttributeMap(write.query.output.zip(newQuery.output))
           val newChild = WriteFiles(newQuery)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index b5353455dc2..9a75cc5ff8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -250,7 +250,12 @@ class ExplainSuite extends ExplainSuiteHelper with 
DisableAdaptiveExecutionSuite
     withTable("temptable") {
       val df = sql("create table temptable using parquet as select * from 
range(2)")
       withNormalizedExplain(df, SimpleMode) { normalizedOutput =>
-        
assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length
 == 1)
+        // scalastyle:off
+        // == Physical Plan ==
+        // Execute CreateDataSourceTableAsSelectCommand
+        //   +- CreateDataSourceTableAsSelectCommand 
`spark_catalog`.`default`.`temptable`, ErrorIfExists, Project [id#5L], [id]
+        // scalastyle:on
+        
assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length
 == 2)
       }
     }
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 88baf76ba7a..1f10ff36acb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, 
SparkListenerEvent, SparkListe
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, 
LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, 
ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, 
PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, 
ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
@@ -37,7 +37,7 @@ import 
org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, 
ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, 
BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, 
ShuffledJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import 
org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import 
org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, 
SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1150,18 +1150,31 @@ class AdaptiveQueryExecSuite
         SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
         SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) {
         withTable("t1") {
-          val df = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col")
-          val plan = df.queryExecution.executedPlan
-          assert(plan.isInstanceOf[CommandResultExec])
-          val commandPhysicalPlan = 
plan.asInstanceOf[CommandResultExec].commandPhysicalPlan
-          if (enabled) {
-            assert(commandPhysicalPlan.isInstanceOf[AdaptiveSparkPlanExec])
-            assert(commandPhysicalPlan.asInstanceOf[AdaptiveSparkPlanExec]
-              .executedPlan.isInstanceOf[DataWritingCommandExec])
-          } else {
-            assert(commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
-            assert(commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
-              .child.isInstanceOf[AdaptiveSparkPlanExec])
+          var checkDone = false
+          val listener = new SparkListener {
+            override def onOtherEvent(event: SparkListenerEvent): Unit = {
+              event match {
+                case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
+                  if (enabled) {
+                    assert(planInfo.nodeName == "AdaptiveSparkPlan")
+                    assert(planInfo.children.size == 1)
+                    assert(planInfo.children.head.nodeName ==
+                      "Execute InsertIntoHadoopFsRelationCommand")
+                  } else {
+                    assert(planInfo.nodeName == "Execute 
InsertIntoHadoopFsRelationCommand")
+                  }
+                  checkDone = true
+                case _ => // ignore other events
+              }
+            }
+          }
+          spark.sparkContext.addSparkListener(listener)
+          try {
+            sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
+            spark.sparkContext.listenerBus.waitUntilEmpty()
+            assert(checkDone)
+          } finally {
+            spark.sparkContext.removeSparkListener(listener)
           }
         }
       }
@@ -1209,16 +1222,12 @@ class AdaptiveQueryExecSuite
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
       withTable("t1") {
-        var checkDone = false
+        var commands: Seq[SparkPlanInfo] = Seq.empty
         val listener = new SparkListener {
           override def onOtherEvent(event: SparkListenerEvent): Unit = {
             event match {
-              case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
-                assert(planInfo.nodeName == "AdaptiveSparkPlan")
-                assert(planInfo.children.size == 1)
-                assert(planInfo.children.head.nodeName ==
-                  "Execute CreateDataSourceTableAsSelectCommand")
-                checkDone = true
+              case start: SparkListenerSQLExecutionStart =>
+                commands = commands ++ Seq(start.sparkPlanInfo)
               case _ => // ignore other events
             }
           }
@@ -1227,7 +1236,12 @@ class AdaptiveQueryExecSuite
         try {
           sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
           spark.sparkContext.listenerBus.waitUntilEmpty()
-          assert(checkDone)
+          assert(commands.size == 3)
+          assert(commands.head.nodeName == "Execute 
CreateDataSourceTableAsSelectCommand")
+          assert(commands(1).nodeName == "AdaptiveSparkPlan")
+          assert(commands(1).children.size == 1)
+          assert(commands(1).children.head.nodeName == "Execute 
InsertIntoHadoopFsRelationCommand")
+          assert(commands(2).nodeName == "CommandResult")
         } finally {
           spark.sparkContext.removeSparkListener(listener)
         }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index eb2aa09e075..e9c5c77e6d9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -65,7 +65,7 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit = {
         qe.optimizedPlan match {
           case w: V1WriteCommand =>
-            if (hasLogicalSort) {
+            if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) 
{
               assert(w.query.isInstanceOf[WriteFiles])
               optimizedPlan = w.query.asInstanceOf[WriteFiles].child
             } else {
@@ -86,16 +86,15 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
 
     sparkContext.listenerBus.waitUntilEmpty()
 
+    assert(optimizedPlan != null)
     // Check whether a logical sort node is at the top of the logical plan of 
the write query.
-    if (optimizedPlan != null) {
-      assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
-        s"Expect hasLogicalSort: $hasLogicalSort, Actual: 
${optimizedPlan.isInstanceOf[Sort]}")
+    assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
+      s"Expect hasLogicalSort: $hasLogicalSort, Actual: 
${optimizedPlan.isInstanceOf[Sort]}")
 
-      // Check empty2null conversion.
-      val empty2nullExpr = optimizedPlan.exists(p => 
V1WritesUtils.hasEmptyToNull(p.expressions))
-      assert(empty2nullExpr == hasEmpty2Null,
-        s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. 
Plan:\n$optimizedPlan")
-    }
+    // Check empty2null conversion.
+    val empty2nullExpr = optimizedPlan.exists(p => 
V1WritesUtils.hasEmptyToNull(p.expressions))
+    assert(empty2nullExpr == hasEmpty2Null,
+      s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. 
Plan:\n$optimizedPlan")
 
     spark.listenerManager.unregister(listener)
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1f20fb62d37..424052df289 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -34,12 +34,13 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, 
SQLHadoopMapReduceCommitProtocol}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, 
InsertIntoHadoopFsRelationCommand, SQLHadoopMapReduceCommitProtocol, 
V1WriteCommand}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
ShuffledHashJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
 
 // Disable AQE because metric info is different with AQE on/off
@@ -832,18 +833,32 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
 
   test("SPARK-34567: Add metrics for CTAS operator") {
     withTable("t") {
-      val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
-      assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
-      val commandResultExec = 
df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
-      val dataWritingCommandExec =
-        
commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
-      val createTableAsSelect = dataWritingCommandExec.cmd
-      assert(createTableAsSelect.metrics.contains("numFiles"))
-      assert(createTableAsSelect.metrics("numFiles").value == 1)
-      assert(createTableAsSelect.metrics.contains("numOutputBytes"))
-      assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
-      assert(createTableAsSelect.metrics.contains("numOutputRows"))
-      assert(createTableAsSelect.metrics("numOutputRows").value == 1)
+      var v1WriteCommand: V1WriteCommand = null
+      val listener = new QueryExecutionListener {
+        override def onFailure(f: String, qe: QueryExecution, e: Exception): 
Unit = {}
+        override def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
+          qe.executedPlan match {
+            case dataWritingCommandExec: DataWritingCommandExec =>
+              val createTableAsSelect = dataWritingCommandExec.cmd
+              v1WriteCommand = 
createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+            case _ =>
+          }
+        }
+      }
+      spark.listenerManager.register(listener)
+      try {
+        val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(v1WriteCommand != null)
+        assert(v1WriteCommand.metrics.contains("numFiles"))
+        assert(v1WriteCommand.metrics("numFiles").value == 1)
+        assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+        assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+        assert(v1WriteCommand.metrics.contains("numOutputRows"))
+        assert(v1WriteCommand.metrics("numOutputRows").value == 1)
+      } finally {
+        spark.listenerManager.unregister(listener)
+      }
     }
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index dd6acc983b7..2fc1f10d3ea 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -217,10 +217,14 @@ class DataFrameCallbackSuite extends QueryTest
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as 
"p").write.partitionBy("p").saveAsTable("tab")
       sparkContext.listenerBus.waitUntilEmpty()
-      assert(commands.length == 5)
-      assert(commands(4)._1 == "command")
-      assert(commands(4)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
-      assert(commands(4)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]
+      // CTAS would derive 3 query executions
+      // 1. CreateDataSourceTableAsSelectCommand
+      // 2. InsertIntoHadoopFsRelationCommand
+      // 3. CommandResultExec
+      assert(commands.length == 6)
+      assert(commands(5)._1 == "command")
+      assert(commands(5)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
+      assert(commands(5)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]
         .table.partitionColumnNames == Seq("p"))
     }
 
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index ce320775027..4dfb2cf65eb 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -21,43 +21,26 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, 
InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, 
V1WritesUtils}
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils, 
LeafRunnableCommand}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, 
InsertIntoHadoopFsRelationCommand, LogicalRelation}
 import org.apache.spark.sql.hive.HiveSessionCatalog
 import org.apache.spark.util.Utils
 
-trait CreateHiveTableAsSelectBase extends V1WriteCommand with 
V1WritesHiveUtils {
+trait CreateHiveTableAsSelectBase extends LeafRunnableCommand {
   val tableDesc: CatalogTable
   val query: LogicalPlan
   val outputColumnNames: Seq[String]
   val mode: SaveMode
 
-  protected val tableIdentifier = tableDesc.identifier
+  assert(query.resolved)
+  override def innerChildren: Seq[LogicalPlan] = query :: Nil
 
-  override lazy val partitionColumns: Seq[Attribute] = {
-    // If the table does not exist the schema should always be empty.
-    val table = if (tableDesc.schema.isEmpty) {
-      val tableSchema = 
CharVarcharUtils.getRawSchema(outputColumns.toStructType, conf)
-      tableDesc.copy(schema = tableSchema)
-    } else {
-      tableDesc
-    }
-    // For CTAS, there is no static partition values to insert.
-    val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap
-    getDynamicPartitionColumns(table, partition, query)
-  }
-
-  override def requiredOrdering: Seq[SortOrder] = {
-    val options = getOptionsWithHiveBucketWrite(tableDesc.bucketSpec)
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, 
tableDesc.bucketSpec, options)
-  }
+  protected val tableIdentifier = tableDesc.identifier
 
-  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+  override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
     val tableExists = catalog.tableExists(tableIdentifier)
 
@@ -74,8 +57,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with 
V1WritesHiveUtils
       }
 
       val command = getWritingCommand(catalog, tableDesc, tableExists = true)
-      command.run(sparkSession, child)
-      DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, 
metrics)
+      val qe = sparkSession.sessionState.executePlan(command)
+      qe.assertCommandExecuted()
     } else {
         tableDesc.storage.locationUri.foreach { p =>
           DataWritingCommand.assertEmptyRootPath(p, mode, 
sparkSession.sessionState.newHadoopConf)
@@ -83,6 +66,7 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with 
V1WritesHiveUtils
       // TODO ideally, we should get the output data ready first and then
       // add the relation into catalog, just in case of failure occurs while 
data
       // processing.
+      val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(query, 
outputColumnNames)
       val tableSchema = CharVarcharUtils.getRawSchema(
         outputColumns.toStructType, sparkSession.sessionState.conf)
       assert(tableDesc.schema.isEmpty)
@@ -93,8 +77,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with 
V1WritesHiveUtils
         // Read back the metadata of the table which was created just now.
         val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
         val command = getWritingCommand(catalog, createdTableMeta, tableExists 
= false)
-        command.run(sparkSession, child)
-        DataWritingCommand.propogateMetrics(sparkSession.sparkContext, 
command, metrics)
+        val qe = sparkSession.sessionState.executePlan(command)
+        qe.assertCommandExecuted()
       } catch {
         case NonFatal(e) =>
           // drop the created table.
@@ -154,9 +138,6 @@ case class CreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHiveTable])
-
-  override protected def withNewChildInternal(
-    newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = 
newChild)
 }
 
 /**
@@ -204,7 +185,4 @@ case class OptimizedCreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand])
-
-  override protected def withNewChildInternal(
-    newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = 
copy(query = newChild)
 }
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 85c2cd53957..258b101dd21 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -102,10 +102,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleto
 
   test("explain create table command") {
     checkKeywordsExist(sql("explain create table temp__b using hive as select 
* from src limit 2"),
-                   "== Physical Plan ==",
-                   "InsertIntoHiveTable",
-                   "Limit",
-                   "src")
+      "== Physical Plan ==",
+      "CreateHiveTableAsSelect")
 
     checkKeywordsExist(
       sql("explain extended create table temp__b using hive as select * from 
src limit 2"),
@@ -113,10 +111,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleto
       "== Analyzed Logical Plan ==",
       "== Optimized Logical Plan ==",
       "== Physical Plan ==",
-      "CreateHiveTableAsSelect",
-      "InsertIntoHiveTable",
-      "Limit",
-      "src")
+      "CreateHiveTableAsSelect")
 
     checkKeywordsExist(sql(
       """
@@ -131,10 +126,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleto
       "== Analyzed Logical Plan ==",
       "== Optimized Logical Plan ==",
       "== Physical Plan ==",
-      "CreateHiveTableAsSelect",
-      "InsertIntoHiveTable",
-      "Limit",
-      "src")
+      "CreateHiveTableAsSelect")
   }
 
   test("explain output of physical plan should contain proper codegen stage 
ID",
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
index 7f6272666a6..c5a84b930a9 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.sql.execution.CommandResultExec
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import 
org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, 
V1WriteCommand}
 import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
 import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.tags.SlowHiveTest
 
 // Disable AQE because metric info is different with AQE on/off
@@ -44,23 +46,36 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with 
TestHiveSingleton
     Seq(false, true).foreach { canOptimized =>
       withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> 
canOptimized.toString) {
         withTable("t") {
-          val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
-          
assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
-          val commandResultExec = 
df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
-          val dataWritingCommandExec =
-            
commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
-          val createTableAsSelect = dataWritingCommandExec.cmd
-          if (canOptimized) {
-            
assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand])
-          } else {
-            
assert(createTableAsSelect.isInstanceOf[CreateHiveTableAsSelectCommand])
+          var v1WriteCommand: V1WriteCommand = null
+          val listener = new QueryExecutionListener {
+            override def onFailure(f: String, qe: QueryExecution, e: 
Exception): Unit = {}
+            override def onSuccess(funcName: String, qe: QueryExecution, 
duration: Long): Unit = {
+              qe.executedPlan match {
+                case dataWritingCommandExec: DataWritingCommandExec =>
+                  val createTableAsSelect = dataWritingCommandExec.cmd
+                  v1WriteCommand = if (canOptimized) {
+                    
createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+                  } else {
+                    createTableAsSelect.asInstanceOf[InsertIntoHiveTable]
+                  }
+                case _ =>
+              }
+            }
+          }
+          spark.listenerManager.register(listener)
+          try {
+            sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
+            sparkContext.listenerBus.waitUntilEmpty()
+            assert(v1WriteCommand != null)
+            assert(v1WriteCommand.metrics.contains("numFiles"))
+            assert(v1WriteCommand.metrics("numFiles").value == 1)
+            assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+            assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+            assert(v1WriteCommand.metrics.contains("numOutputRows"))
+            assert(v1WriteCommand.metrics("numOutputRows").value == 1)
+          } finally {
+            spark.listenerManager.unregister(listener)
           }
-          assert(createTableAsSelect.metrics.contains("numFiles"))
-          assert(createTableAsSelect.metrics("numFiles").value == 1)
-          assert(createTableAsSelect.metrics.contains("numOutputBytes"))
-          assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
-          assert(createTableAsSelect.metrics.contains("numOutputRows"))
-          assert(createTableAsSelect.metrics("numOutputRows").value == 1)
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to