This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4b40920e331 [SPARK-41713][SQL] Make CTAS hold a nested execution for data writing 4b40920e331 is described below commit 4b40920e33176fc8b18380703e4dcf4d16824094 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Wed Dec 28 17:11:59 2022 +0800 [SPARK-41713][SQL] Make CTAS hold a nested execution for data writing ### What changes were proposed in this pull request? This PR aims to make CTAS use a nested execution instead of running the data writing command. So, we can clean up CTAS itself to remove the unnecessary v1write information. Now, the v1writes only have two implementations: `InsertIntoHadoopFsRelationCommand` and `InsertIntoHiveTable` ### Why are the changes needed? Make v1writes code clear. ```sql EXPLAIN FORMATTED CREATE TABLE t2 USING PARQUET AS SELECT * FROM t; == Physical Plan == Execute CreateDataSourceTableAsSelectCommand (1) +- CreateDataSourceTableAsSelectCommand (2) +- Project (5) +- SubqueryAlias (4) +- LogicalRelation (3) (1) Execute CreateDataSourceTableAsSelectCommand Output: [] (2) CreateDataSourceTableAsSelectCommand Arguments: `spark_catalog`.`default`.`t2`, ErrorIfExists, [c1, c2] (3) LogicalRelation Arguments: parquet, [c1#11, c2#12], `spark_catalog`.`default`.`t`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, false (4) SubqueryAlias Arguments: spark_catalog.default.t (5) Project Arguments: [c1#11, c2#12] ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? Improved existing tests. Closes #39220 from ulysses-you/SPARK-41713. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../execution/command/createDataSourceTables.scala | 40 +++------------ .../sql/execution/datasources/DataSource.scala | 48 ++++++------------ .../spark/sql/execution/datasources/V1Writes.scala | 8 +-- .../scala/org/apache/spark/sql/ExplainSuite.scala | 7 ++- .../adaptive/AdaptiveQueryExecSuite.scala | 58 ++++++++++++++-------- .../datasources/V1WriteCommandSuite.scala | 17 +++---- .../sql/execution/metric/SQLMetricsSuite.scala | 41 ++++++++++----- .../spark/sql/util/DataFrameCallbackSuite.scala | 12 +++-- .../execution/CreateHiveTableAsSelectCommand.scala | 46 +++++------------ .../sql/hive/execution/HiveExplainSuite.scala | 16 ++---- .../spark/sql/hive/execution/SQLMetricsSuite.scala | 49 +++++++++++------- 11 files changed, 159 insertions(+), 183 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 9bf9f43829e..bf14ef14cf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command import java.net.URI import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan} +import org.apache.spark.sql.execution.CommandExecutionMode import org.apache.spark.sql.execution.datasources._ import 
org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -143,29 +141,11 @@ case class CreateDataSourceTableAsSelectCommand( mode: SaveMode, query: LogicalPlan, outputColumnNames: Seq[String]) - extends V1WriteCommand { - - override def fileFormatProvider: Boolean = { - table.provider.forall { provider => - classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, conf)) - } - } - - override lazy val partitionColumns: Seq[Attribute] = { - val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted) - DataSource.resolvePartitionColumns( - unresolvedPartitionColumns, - outputColumns, - query, - SparkSession.active.sessionState.conf.resolver) - } - - override def requiredOrdering: Seq[SortOrder] = { - val options = table.storage.properties - V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options) - } + extends LeafRunnableCommand { + assert(query.resolved) + override def innerChildren: Seq[LogicalPlan] = query :: Nil - override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -187,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, SaveMode.Append, tableExists = true) } else { table.storage.locationUri.foreach { p => DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf) @@ -200,7 +180,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, SaveMode.Overwrite, tableExists = false) val tableSchema = 
CharVarcharUtils.getRawSchema(result.schema, sessionState.conf) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), @@ -232,7 +212,6 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { // Create the relation based on the input logical plan: `query`. @@ -246,14 +225,11 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics) + dataSource.writeAndRead(mode, query, outputColumnNames) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) throw ex } } - - override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = - copy(query = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 3d8eb9bc8a8..edbdd6bbc67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils} import org.apache.spark.sql.connector.catalog.TableProvider import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider @@ -45,7 +44,6 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat 
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 -import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf @@ -97,8 +95,19 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) - lazy val providingClass: Class[_] = - DataSource.providingClass(className, sparkSession.sessionState.conf) + lazy val providingClass: Class[_] = { + val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf) + // `providingClass` is used for resolving data source relation for catalog tables. + // As now catalog for data source V2 is under development, here we fall back all the + // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works. + // [[FileDataSourceV2]] will still be used if we call the load()/save() method in + // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` + // instead of `providingClass`. + cls.newInstance() match { + case f: FileDataSourceV2 => f.fallbackFileFormat + case _ => cls + } + } private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance() @@ -483,17 +492,11 @@ case class DataSource( * @param outputColumnNames The original output column names of the input query plan. The * optimizer may not preserve the output column's names' case, so we need * this parameter instead of `data.output`. - * @param physicalPlan The physical plan of the input query plan. 
We should run the writing - * command with this physical plan instead of creating a new physical plan, - * so that the metrics can be correctly linked to the given physical plan and - * shown in the web UI. */ def writeAndRead( mode: SaveMode, data: LogicalPlan, - outputColumnNames: Seq[String], - physicalPlan: SparkPlan, - metrics: Map[String, SQLMetric]): BaseRelation = { + outputColumnNames: Seq[String]): BaseRelation = { val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames) providingInstance() match { case dataSource: CreatableRelationProvider => @@ -503,13 +506,8 @@ case class DataSource( case format: FileFormat => disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false) val cmd = planForWritingFileFormat(format, mode, data) - val resolvedPartCols = - DataSource.resolvePartitionColumns(cmd.partitionColumns, outputColumns, data, equality) - val resolved = cmd.copy( - partitionColumns = resolvedPartCols, - outputColumnNames = outputColumnNames) - resolved.run(sparkSession, physicalPlan) - DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics) + val qe = sparkSession.sessionState.executePlan(cmd) + qe.assertCommandExecuted() // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => throw new IllegalStateException( @@ -832,18 +830,4 @@ object DataSource extends Logging { } } } - - def providingClass(className: String, conf: SQLConf): Class[_] = { - val cls = DataSource.lookupDataSource(className, conf) - // `providingClass` is used for resolving data source relation for catalog tables. - // As now catalog for data source V2 is under development, here we fall back all the - // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works. 
- // [[FileDataSourceV2]] will still be used if we call the load()/save() method in - // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` - // instead of `providingClass`. - cls.newInstance() match { - case f: FileDataSourceV2 => f.fallbackFileFormat - case _ => cls - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala index e9f6e3df785..3ed04e5bd6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala @@ -31,11 +31,6 @@ import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String trait V1WriteCommand extends DataWritingCommand { - /** - * Return if the provider is [[FileFormat]] - */ - def fileFormatProvider: Boolean = true - /** * Specify the partition columns of the V1 write command. 
*/ @@ -58,8 +53,7 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper { override def apply(plan: LogicalPlan): LogicalPlan = { if (conf.plannedWriteEnabled) { plan.transformUp { - case write: V1WriteCommand if write.fileFormatProvider && - !write.child.isInstanceOf[WriteFiles] => + case write: V1WriteCommand if !write.child.isInstanceOf[WriteFiles] => val newQuery = prepareQuery(write, write.query) val attrMap = AttributeMap(write.query.output.zip(newQuery.output)) val newChild = WriteFiles(newQuery) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index b5353455dc2..9a75cc5ff8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -250,7 +250,12 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite withTable("temptable") { val df = sql("create table temptable using parquet as select * from range(2)") withNormalizedExplain(df, SimpleMode) { normalizedOutput => - assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 1) + // scalastyle:off + // == Physical Plan == + // Execute CreateDataSourceTableAsSelectCommand + // +- CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`temptable`, ErrorIfExists, Project [id#5L], [id] + // scalastyle:on + assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 88baf76ba7a..1f10ff36acb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -29,7 +29,7 @@ import 
org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -1150,18 +1150,31 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true", SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) { withTable("t1") 
{ - val df = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col") - val plan = df.queryExecution.executedPlan - assert(plan.isInstanceOf[CommandResultExec]) - val commandPhysicalPlan = plan.asInstanceOf[CommandResultExec].commandPhysicalPlan - if (enabled) { - assert(commandPhysicalPlan.isInstanceOf[AdaptiveSparkPlanExec]) - assert(commandPhysicalPlan.asInstanceOf[AdaptiveSparkPlanExec] - .executedPlan.isInstanceOf[DataWritingCommandExec]) - } else { - assert(commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) - assert(commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - .child.isInstanceOf[AdaptiveSparkPlanExec]) + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => + if (enabled) { + assert(planInfo.nodeName == "AdaptiveSparkPlan") + assert(planInfo.children.size == 1) + assert(planInfo.children.head.nodeName == + "Execute InsertIntoHadoopFsRelationCommand") + } else { + assert(planInfo.nodeName == "Execute InsertIntoHadoopFsRelationCommand") + } + checkDone = true + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + try { + sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) } } } @@ -1209,16 +1222,12 @@ class AdaptiveQueryExecSuite withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { withTable("t1") { - var checkDone = false + var commands: Seq[SparkPlanInfo] = Seq.empty val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "AdaptiveSparkPlan") - assert(planInfo.children.size == 1) - 
assert(planInfo.children.head.nodeName == - "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true + case start: SparkListenerSQLExecutionStart => + commands = commands ++ Seq(start.sparkPlanInfo) case _ => // ignore other events } } @@ -1227,7 +1236,12 @@ class AdaptiveQueryExecSuite try { sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) + assert(commands.size == 3) + assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + assert(commands(1).nodeName == "AdaptiveSparkPlan") + assert(commands(1).children.size == 1) + assert(commands(1).children.head.nodeName == "Execute InsertIntoHadoopFsRelationCommand") + assert(commands(2).nodeName == "CommandResult") } finally { spark.sparkContext.removeSparkListener(listener) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala index eb2aa09e075..e9c5c77e6d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala @@ -65,7 +65,7 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { qe.optimizedPlan match { case w: V1WriteCommand => - if (hasLogicalSort) { + if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) { assert(w.query.isInstanceOf[WriteFiles]) optimizedPlan = w.query.asInstanceOf[WriteFiles].child } else { @@ -86,16 +86,15 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils { sparkContext.listenerBus.waitUntilEmpty() + assert(optimizedPlan != null) // Check whether a logical sort node is at the top of the logical plan of the write query. 
- if (optimizedPlan != null) { - assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort, - s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}") + assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort, + s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}") - // Check empty2null conversion. - val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions)) - assert(empty2nullExpr == hasEmpty2Null, - s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan") - } + // Check empty2null conversion. + val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions)) + assert(empty2nullExpr == hasEmpty2Null, + s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan") spark.listenerManager.unregister(listener) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 1f20fb62d37..424052df289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -34,12 +34,13 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec -import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, SQLHadoopMapReduceCommitProtocol} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, InsertIntoHadoopFsRelationCommand, SQLHadoopMapReduceCommitProtocol, V1WriteCommand} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import 
org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.util.{AccumulatorContext, JsonProtocol} // Disable AQE because metric info is different with AQE on/off @@ -832,18 +833,32 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SPARK-34567: Add metrics for CTAS operator") { withTable("t") { - val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a") - assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec]) - val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec] - val dataWritingCommandExec = - commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - val createTableAsSelect = dataWritingCommandExec.cmd - assert(createTableAsSelect.metrics.contains("numFiles")) - assert(createTableAsSelect.metrics("numFiles").value == 1) - assert(createTableAsSelect.metrics.contains("numOutputBytes")) - assert(createTableAsSelect.metrics("numOutputBytes").value > 0) - assert(createTableAsSelect.metrics.contains("numOutputRows")) - assert(createTableAsSelect.metrics("numOutputRows").value == 1) + var v1WriteCommand: V1WriteCommand = null + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.executedPlan match { + case dataWritingCommandExec: DataWritingCommandExec => + val createTableAsSelect = dataWritingCommandExec.cmd + v1WriteCommand = createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand] + case _ => + } + } + } + spark.listenerManager.register(listener) + try { + val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a") + 
sparkContext.listenerBus.waitUntilEmpty() + assert(v1WriteCommand != null) + assert(v1WriteCommand.metrics.contains("numFiles")) + assert(v1WriteCommand.metrics("numFiles").value == 1) + assert(v1WriteCommand.metrics.contains("numOutputBytes")) + assert(v1WriteCommand.metrics("numOutputBytes").value > 0) + assert(v1WriteCommand.metrics.contains("numOutputRows")) + assert(v1WriteCommand.metrics("numOutputRows").value == 1) + } finally { + spark.listenerManager.unregister(listener) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index dd6acc983b7..2fc1f10d3ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -217,10 +217,14 @@ class DataFrameCallbackSuite extends QueryTest withTable("tab") { spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") sparkContext.listenerBus.waitUntilEmpty() - assert(commands.length == 5) - assert(commands(4)._1 == "command") - assert(commands(4)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand]) - assert(commands(4)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand] + // CTAS would derive 3 query executions + // 1. CreateDataSourceTableAsSelectCommand + // 2. InsertIntoHadoopFsRelationCommand + // 3. 
CommandResultExec + assert(commands.length == 6) + assert(commands(5)._1 == "command") + assert(commands(5)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand]) + assert(commands(5)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand] .table.partitionColumnNames == Seq("p")) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index ce320775027..4dfb2cf65eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -21,43 +21,26 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, V1WritesUtils} +import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils, LeafRunnableCommand} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation} import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.util.Utils -trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils { +trait CreateHiveTableAsSelectBase extends LeafRunnableCommand { val tableDesc: CatalogTable val query: LogicalPlan val outputColumnNames: Seq[String] 
val mode: SaveMode - protected val tableIdentifier = tableDesc.identifier + assert(query.resolved) + override def innerChildren: Seq[LogicalPlan] = query :: Nil - override lazy val partitionColumns: Seq[Attribute] = { - // If the table does not exist the schema should always be empty. - val table = if (tableDesc.schema.isEmpty) { - val tableSchema = CharVarcharUtils.getRawSchema(outputColumns.toStructType, conf) - tableDesc.copy(schema = tableSchema) - } else { - tableDesc - } - // For CTAS, there is no static partition values to insert. - val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap - getDynamicPartitionColumns(table, partition, query) - } - - override def requiredOrdering: Seq[SortOrder] = { - val options = getOptionsWithHiveBucketWrite(tableDesc.bucketSpec) - V1WritesUtils.getSortOrder(outputColumns, partitionColumns, tableDesc.bucketSpec, options) - } + protected val tableIdentifier = tableDesc.identifier - override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val tableExists = catalog.tableExists(tableIdentifier) @@ -74,8 +57,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils } val command = getWritingCommand(catalog, tableDesc, tableExists = true) - command.run(sparkSession, child) - DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics) + val qe = sparkSession.sessionState.executePlan(command) + qe.assertCommandExecuted() } else { tableDesc.storage.locationUri.foreach { p => DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf) @@ -83,6 +66,7 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. 
+ val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames) val tableSchema = CharVarcharUtils.getRawSchema( outputColumns.toStructType, sparkSession.sessionState.conf) assert(tableDesc.schema.isEmpty) @@ -93,8 +77,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils // Read back the metadata of the table which was created just now. val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) val command = getWritingCommand(catalog, createdTableMeta, tableExists = false) - command.run(sparkSession, child) - DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics) + val qe = sparkSession.sessionState.executePlan(command) + qe.assertCommandExecuted() } catch { case NonFatal(e) => // drop the created table. @@ -154,9 +138,6 @@ case class CreateHiveTableAsSelectCommand( override def writingCommandClassName: String = Utils.getSimpleName(classOf[InsertIntoHiveTable]) - - override protected def withNewChildInternal( - newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild) } /** @@ -204,7 +185,4 @@ case class OptimizedCreateHiveTableAsSelectCommand( override def writingCommandClassName: String = Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand]) - - override protected def withNewChildInternal( - newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 85c2cd53957..258b101dd21 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -102,10 +102,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain create table command") { 
checkKeywordsExist(sql("explain create table temp__b using hive as select * from src limit 2"), - "== Physical Plan ==", - "InsertIntoHiveTable", - "Limit", - "src") + "== Physical Plan ==", + "CreateHiveTableAsSelect") checkKeywordsExist( sql("explain extended create table temp__b using hive as select * from src limit 2"), @@ -113,10 +111,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "CreateHiveTableAsSelect", - "InsertIntoHiveTable", - "Limit", - "src") + "CreateHiveTableAsSelect") checkKeywordsExist(sql( """ @@ -131,10 +126,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "CreateHiveTableAsSelect", - "InsertIntoHiveTable", - "Limit", - "src") + "CreateHiveTableAsSelect") } test("explain output of physical plan should contain proper codegen stage ID", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala index 7f6272666a6..c5a84b930a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.CommandResultExec +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, V1WriteCommand} import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton 
+import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.tags.SlowHiveTest // Disable AQE because metric info is different with AQE on/off @@ -44,23 +46,36 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton Seq(false, true).foreach { canOptimized => withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) { withTable("t") { - val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a") - assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec]) - val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec] - val dataWritingCommandExec = - commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] - val createTableAsSelect = dataWritingCommandExec.cmd - if (canOptimized) { - assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand]) - } else { - assert(createTableAsSelect.isInstanceOf[CreateHiveTableAsSelectCommand]) + var v1WriteCommand: V1WriteCommand = null + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.executedPlan match { + case dataWritingCommandExec: DataWritingCommandExec => + val createTableAsSelect = dataWritingCommandExec.cmd + v1WriteCommand = if (canOptimized) { + createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand] + } else { + createTableAsSelect.asInstanceOf[InsertIntoHiveTable] + } + case _ => + } + } + } + spark.listenerManager.register(listener) + try { + sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a") + sparkContext.listenerBus.waitUntilEmpty() + assert(v1WriteCommand != null) + assert(v1WriteCommand.metrics.contains("numFiles")) + assert(v1WriteCommand.metrics("numFiles").value == 1) + assert(v1WriteCommand.metrics.contains("numOutputBytes")) + assert(v1WriteCommand.metrics("numOutputBytes").value > 0) 
+ assert(v1WriteCommand.metrics.contains("numOutputRows")) + assert(v1WriteCommand.metrics("numOutputRows").value == 1) + } finally { + spark.listenerManager.unregister(listener) } - assert(createTableAsSelect.metrics.contains("numFiles")) - assert(createTableAsSelect.metrics("numFiles").value == 1) - assert(createTableAsSelect.metrics.contains("numOutputBytes")) - assert(createTableAsSelect.metrics("numOutputBytes").value > 0) - assert(createTableAsSelect.metrics.contains("numOutputRows")) - assert(createTableAsSelect.metrics("numOutputRows").value == 1) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org