Repository: spark
Updated Branches:
  refs/heads/master 84454d7d3 -> 95ad960ca


[SPARK-21669] Internal API for collecting metrics/stats during FileFormatWriter 
jobs

## What changes were proposed in this pull request?

This patch introduces an internal interface for tracking metrics and/or 
statistics on data on the fly, as it is being written to disk during a 
`FileFormatWriter` job, and partially reimplements SPARK-20703 in terms of it.

The interface consists of three traits (a minimal sketch of a custom tracker 
follows the list):
- `WriteTaskStats`: a tag for classes that represent statistics collected 
during a `WriteTask`.
  The only constraint it adds is that the class must be `Serializable`, as 
instances of it are collected on the driver from all executors at the end of 
the `WriteJob`.
- `WriteTaskStatsTracker`: a trait for classes that can actually compute 
statistics based on tuples that are processed by a given `WriteTask` and 
eventually produce a `WriteTaskStats` instance.
- `WriteJobStatsTracker`: a trait for classes that act as containers of the 
`Serializable` state needed for instantiating `WriteTaskStatsTracker`s on 
executors and for processing the resulting collection of `WriteTaskStats` 
once they're gathered back on the driver.
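
For illustration only, and not part of this patch, a custom tracker that 
merely counts rows could look roughly like the sketch below. All the 
`RowCount*` names are made up for this example; only the three traits and 
their method signatures come from the patch. Such a tracker would be handed 
to a write job through the new `statsTrackers: Seq[WriteJobStatsTracker]` 
parameter of `FileFormatWriter.write` (see the diff below).

```scala
// Hypothetical example of the new internal API: a tracker that only counts rows.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{
  WriteJobStatsTracker, WriteTaskStats, WriteTaskStatsTracker}

case class RowCountStats(numRows: Long) extends WriteTaskStats

class RowCountTaskStatsTracker extends WriteTaskStatsTracker {
  private var numRows: Long = 0L

  override def newPartition(partitionValues: InternalRow): Unit = {}
  override def newBucket(bucketId: Int): Unit = {}
  override def newFile(filePath: String): Unit = {}
  override def newRow(row: InternalRow): Unit = numRows += 1

  override def getFinalStats(): WriteTaskStats = RowCountStats(numRows)
}

class RowCountJobStatsTracker extends WriteJobStatsTracker {
  // one task-level tracker is instantiated per write task, on the executors
  override def newTaskInstance(): WriteTaskStatsTracker = new RowCountTaskStatsTracker

  // called once on the driver, with one RowCountStats per successful task
  override def processStats(stats: Seq[WriteTaskStats]): Unit = {
    val totalRows = stats.map(_.asInstanceOf[RowCountStats].numRows).sum
    // a real tracker might post SQL metrics or update catalog statistics here
    println(s"Write job produced $totalRows rows")
  }
}
```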

A potential future use of this interface is, e.g., CBO stats maintenance 
during `INSERT INTO table ... ` operations.
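
Any such future tracker plugs into the same driver-side fan-out. The snippet 
below is a simplified rendering of the new `processStats()` helper in 
`FileFormatWriter` (its full version, including an assertion that every task 
produced one result per tracker, is in the diff below):

```scala
import org.apache.spark.sql.execution.datasources.{WriteJobStatsTracker, WriteTaskStats}

// Each task returns one WriteTaskStats per registered tracker, so transposing
// gives every WriteJobStatsTracker the column of stats that belongs to it.
def processStats(
    statsTrackers: Seq[WriteJobStatsTracker],
    statsPerTask: Seq[Seq[WriteTaskStats]]): Unit = {
  val statsPerTracker = if (statsPerTask.nonEmpty) {
    statsPerTask.transpose
  } else {
    statsTrackers.map(_ => Seq.empty[WriteTaskStats])
  }
  statsTrackers.zip(statsPerTracker).foreach {
    case (tracker, stats) => tracker.processStats(stats)
  }
}
```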

## How was this patch tested?

Existing tests for SPARK-20703 exercise the new code: `hive/SQLMetricsSuite`, 
`sql/JavaDataFrameReaderWriterSuite`, etc.

Author: Adrian Ionescu <adr...@databricks.com>

Closes #18884 from adrian-ionescu/write-stats-tracker-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95ad960c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95ad960c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95ad960c

Branch: refs/heads/master
Commit: 95ad960caf009d843ec700ee41cbccc2fa3a68a5
Parents: 84454d7
Author: Adrian Ionescu <adr...@databricks.com>
Authored: Thu Aug 10 12:37:10 2017 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Aug 10 12:37:10 2017 -0700

----------------------------------------------------------------------
 .../execution/command/DataWritingCommand.scala  |  34 +--
 .../datasources/BasicWriteStatsTracker.scala    | 133 ++++++++++
 .../datasources/FileFormatWriter.scala          | 245 ++++++++++---------
 .../InsertIntoHadoopFsRelationCommand.scala     |  43 ++--
 .../datasources/WriteStatsTracker.scala         | 121 +++++++++
 .../execution/streaming/FileStreamSink.scala    |   2 +-
 .../hive/execution/InsertIntoHiveTable.scala    |   4 +-
 7 files changed, 420 insertions(+), 162 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index 700f7f8..4e1c5e4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution.command
 
+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.ExecutedWriteSummary
+import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.util.SerializableConfiguration
+
 
 /**
  * A special `RunnableCommand` which writes data out and updates metrics.
@@ -37,29 +40,8 @@ trait DataWritingCommand extends RunnableCommand {
     )
   }
 
-  /**
-   * Callback function that update metrics collected from the writing 
operation.
-   */
-  protected def updateWritingMetrics(writeSummaries: 
Seq[ExecutedWriteSummary]): Unit = {
-    val sparkContext = SparkContext.getActive.get
-    var numPartitions = 0
-    var numFiles = 0
-    var totalNumBytes: Long = 0L
-    var totalNumOutput: Long = 0L
-
-    writeSummaries.foreach { summary =>
-      numPartitions += summary.updatedPartitions.size
-      numFiles += summary.numOutputFile
-      totalNumBytes += summary.numOutputBytes
-      totalNumOutput += summary.numOutputRows
-    }
-
-    metrics("numFiles").add(numFiles)
-    metrics("numOutputBytes").add(totalNumBytes)
-    metrics("numOutputRows").add(totalNumOutput)
-    metrics("numParts").add(numPartitions)
-
-    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, 
metrics.values.toList)
+  def basicWriteJobStatsTracker(hadoopConf: Configuration): 
BasicWriteJobStatsTracker = {
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
new file mode 100644
index 0000000..b8f7d13
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.util.SerializableConfiguration
+
+
+/**
+ * Simple metrics collected during an instance of 
[[FileFormatWriter.ExecuteWriteTask]].
+ * These were first introduced in https://github.com/apache/spark/pull/18159 
(SPARK-20703).
+ */
+case class BasicWriteTaskStats(
+    numPartitions: Int,
+    numFiles: Int,
+    numBytes: Long,
+    numRows: Long)
+  extends WriteTaskStats
+
+
+/**
+ * Simple [[WriteTaskStatsTracker]] implementation that produces 
[[BasicWriteTaskStats]].
+ * @param hadoopConf The Hadoop configuration used to look up file sizes.
+ */
+class BasicWriteTaskStatsTracker(hadoopConf: Configuration)
+  extends WriteTaskStatsTracker {
+
+  private[this] var numPartitions: Int = 0
+  private[this] var numFiles: Int = 0
+  private[this] var numBytes: Long = 0L
+  private[this] var numRows: Long = 0L
+
+  private[this] var curFile: String = null
+
+
+  private def getFileSize(filePath: String): Long = {
+    val path = new Path(filePath)
+    val fs = path.getFileSystem(hadoopConf)
+    fs.getFileStatus(path).getLen()
+  }
+
+
+  override def newPartition(partitionValues: InternalRow): Unit = {
+    numPartitions += 1
+  }
+
+  override def newBucket(bucketId: Int): Unit = {
+    // currently unhandled
+  }
+
+  override def newFile(filePath: String): Unit = {
+    if (numFiles > 0) {
+      // we assume here that we've finished writing to disk the previous file 
by now
+      numBytes += getFileSize(curFile)
+    }
+    curFile = filePath
+    numFiles += 1
+  }
+
+  override def newRow(row: InternalRow): Unit = {
+    numRows += 1
+  }
+
+  override def getFinalStats(): WriteTaskStats = {
+    if (numFiles > 0) {
+      numBytes += getFileSize(curFile)
+    }
+    BasicWriteTaskStats(numPartitions, numFiles, numBytes, numRows)
+  }
+}
+
+
+/**
+ * Simple [[WriteJobStatsTracker]] implementation that's serializable, capable 
of
+ * instantiating [[BasicWriteTaskStatsTracker]] on executors and processing the
+ * [[BasicWriteTaskStats]] they produce by aggregating the metrics and posting 
them
+ * as DriverMetricUpdates.
+ */
+class BasicWriteJobStatsTracker(
+    serializableHadoopConf: SerializableConfiguration,
+    @transient val metrics: Map[String, SQLMetric])
+  extends WriteJobStatsTracker {
+
+  override def newTaskInstance(): WriteTaskStatsTracker = {
+    new BasicWriteTaskStatsTracker(serializableHadoopConf.value)
+  }
+
+  override def processStats(stats: Seq[WriteTaskStats]): Unit = {
+    val sparkContext = SparkContext.getActive.get
+    var numPartitions: Long = 0L
+    var numFiles: Long = 0L
+    var totalNumBytes: Long = 0L
+    var totalNumOutput: Long = 0L
+
+    val basicStats = stats.map(_.asInstanceOf[BasicWriteTaskStats])
+
+    basicStats.foreach { summary =>
+      numPartitions += summary.numPartitions
+      numFiles += summary.numFiles
+      totalNumBytes += summary.numBytes
+      totalNumOutput += summary.numRows
+    }
+
+    metrics("numFiles").add(numFiles)
+    metrics("numOutputBytes").add(totalNumBytes)
+    metrics("numOutputRows").add(totalNumOutput)
+    metrics("numParts").add(numPartitions)
+
+    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, 
metrics.values.toList)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 073e878..68aaa8a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -22,7 +22,7 @@ import java.util.{Date, UUID}
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
@@ -70,7 +70,8 @@ object FileFormatWriter extends Logging {
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long,
-      val timeZoneId: String)
+      val timeZoneId: String,
+      val statsTrackers: Seq[WriteJobStatsTracker])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ 
dataColumns),
@@ -94,6 +95,9 @@ object FileFormatWriter extends Logging {
    *    exception is thrown during task commitment, also aborts that task.
    * 4. If all tasks are committed, commit the job, otherwise aborts the job;  
If any exception is
    *    thrown during job commitment, also aborts the job.
+   * 5. If the job is successfully committed, perform post-commit operations 
such as
+   *    processing statistics.
+   * @return The set of all partition paths that were updated during this 
write job.
    */
   def write(
       sparkSession: SparkSession,
@@ -104,8 +108,9 @@ object FileFormatWriter extends Logging {
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
-      refreshFunction: (Seq[ExecutedWriteSummary]) => Unit,
-      options: Map[String, String]): Unit = {
+      statsTrackers: Seq[WriteJobStatsTracker],
+      options: Map[String, String])
+    : Set[String] = {
 
     val job = Job.getInstance(hadoopConf)
     job.setOutputKeyClass(classOf[Void])
@@ -146,7 +151,8 @@ object FileFormatWriter extends Logging {
       maxRecordsPerFile = 
caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
       timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
-        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
+      statsTrackers = statsTrackers
     )
 
     // We should first sort by partition columns, then bucket id, and finally 
sorting columns.
@@ -199,7 +205,12 @@ object FileFormatWriter extends Logging {
 
       committer.commitJob(job, commitMsgs)
       logInfo(s"Job ${job.getJobID} committed.")
-      refreshFunction(ret.map(_.summary))
+
+      processStats(description.statsTrackers, ret.map(_.summary.stats))
+      logInfo(s"Finished processing stats for job ${job.getJobID}.")
+
+      // return a set of all the partition paths that were updated during this 
job
+      ret.map(_.summary.updatedPartitions).reduceOption(_ ++ 
_).getOrElse(Set.empty)
     } catch { case cause: Throwable =>
       logError(s"Aborting job ${job.getJobID}.", cause)
       committer.abortJob(job)
@@ -238,7 +249,7 @@ object FileFormatWriter extends Logging {
     val writeTask =
       if (sparkPartitionId != 0 && !iterator.hasNext) {
         // In case of empty job, leave first partition to save meta for file 
format like parquet.
-        new EmptyDirectoryWriteTask
+        new EmptyDirectoryWriteTask(description)
       } else if (description.partitionColumns.isEmpty && 
description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, 
committer)
       } else {
@@ -269,6 +280,33 @@ object FileFormatWriter extends Logging {
   }
 
   /**
+   * For every registered [[WriteJobStatsTracker]], call `processStats()` on 
it, passing it
+   * the corresponding [[WriteTaskStats]] from all executors.
+   */
+  private def processStats(
+      statsTrackers: Seq[WriteJobStatsTracker],
+      statsPerTask: Seq[Seq[WriteTaskStats]])
+    : Unit = {
+
+    val numStatsTrackers = statsTrackers.length
+    assert(statsPerTask.forall(_.length == numStatsTrackers),
+      s"""Every WriteTask should have produced one `WriteTaskStats` object for 
every tracker.
+         |There are $numStatsTrackers statsTrackers, but some task returned
+         |${statsPerTask.find(_.length != numStatsTrackers).get.length} 
results instead.
+       """.stripMargin)
+
+    val statsPerTracker = if (statsPerTask.nonEmpty) {
+      statsPerTask.transpose
+    } else {
+      statsTrackers.map(_ => Seq.empty)
+    }
+
+    statsTrackers.zip(statsPerTracker).foreach {
+      case (statsTracker, stats) => statsTracker.processStats(stats)
+    }
+  }
+
+  /**
    * A simple trait for writing out data in a single Spark task, without any 
concerns about how
    * to commit or abort tasks. Exceptions thrown by the implementation of this 
trait will
    * automatically trigger task aborts.
@@ -276,43 +314,26 @@ object FileFormatWriter extends Logging {
   private trait ExecuteWriteTask {
 
     /**
-     * The data structures used to measure metrics during writing.
-     */
-    protected var numOutputRows: Long = 0L
-    protected var numOutputBytes: Long = 0L
-
-    /**
      * Writes data out to files, and then returns the summary of relative 
information which
      * includes the list of partition strings written out. The list of 
partitions is sent back
      * to the driver and used to update the catalog. Other information will be 
sent back to the
-     * driver too and used to update the metrics in UI.
+     * driver too and used to e.g. update the metrics in UI.
      */
     def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary
     def releaseResources(): Unit
-
-    /**
-     * A helper function used to determine the size in  bytes of a written 
file.
-     */
-    protected def getFileSize(conf: Configuration, filePath: String): Long = {
-      if (filePath != null) {
-        val path = new Path(filePath)
-        val fs = path.getFileSystem(conf)
-        fs.getFileStatus(path).getLen()
-      } else {
-        0L
-      }
-    }
   }
 
   /** ExecuteWriteTask for empty partitions */
-  private class EmptyDirectoryWriteTask extends ExecuteWriteTask {
+  private class EmptyDirectoryWriteTask(description: WriteJobDescription)
+    extends ExecuteWriteTask {
+
+    val statsTrackers: Seq[WriteTaskStatsTracker] =
+      description.statsTrackers.map(_.newTaskInstance())
 
     override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
       ExecutedWriteSummary(
         updatedPartitions = Set.empty,
-        numOutputFile = 0,
-        numOutputBytes = 0,
-        numOutputRows = 0)
+        stats = statsTrackers.map(_.getFinalStats()))
     }
 
     override def releaseResources(): Unit = {}
@@ -325,11 +346,13 @@ object FileFormatWriter extends Logging {
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
     private[this] var currentWriter: OutputWriter = _
-    private[this] var currentPath: String = _
+
+    val statsTrackers: Seq[WriteTaskStatsTracker] =
+      description.statsTrackers.map(_.newTaskInstance())
 
     private def newOutputWriter(fileCounter: Int): Unit = {
       val ext = 
description.outputWriterFactory.getFileExtension(taskAttemptContext)
-      currentPath = committer.newTaskTempFile(
+      val currentPath = committer.newTaskTempFile(
         taskAttemptContext,
         None,
         f"-c$fileCounter%03d" + ext)
@@ -338,6 +361,8 @@ object FileFormatWriter extends Logging {
         path = currentPath,
         dataSchema = description.dataColumns.toStructType,
         context = taskAttemptContext)
+
+      statsTrackers.map(_.newFile(currentPath))
     }
 
     override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
@@ -353,29 +378,24 @@ object FileFormatWriter extends Logging {
 
           recordsInFile = 0
           releaseResources()
-          numOutputRows += recordsInFile
           newOutputWriter(fileCounter)
         }
 
         val internalRow = iter.next()
         currentWriter.write(internalRow)
+        statsTrackers.foreach(_.newRow(internalRow))
         recordsInFile += 1
       }
       releaseResources()
-      numOutputRows += recordsInFile
-
       ExecutedWriteSummary(
         updatedPartitions = Set.empty,
-        numOutputFile = fileCounter + 1,
-        numOutputBytes = numOutputBytes,
-        numOutputRows = numOutputRows)
+        stats = statsTrackers.map(_.getFinalStats()))
     }
 
     override def releaseResources(): Unit = {
       if (currentWriter != null) {
         try {
           currentWriter.close()
-          numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, 
currentPath)
         } finally {
           currentWriter = null
         }
@@ -392,30 +412,65 @@ object FileFormatWriter extends Logging {
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
-    // currentWriter is initialized whenever we see a new key
+    /** Flag saying whether or not the data to be written out is partitioned. 
*/
+    val isPartitioned = desc.partitionColumns.nonEmpty
+
+    /** Flag saying whether or not the data to be written out is bucketed. */
+    val isBucketed = desc.bucketIdExpression.isDefined
+
+    assert(isPartitioned || isBucketed,
+      s"""DynamicPartitionWriteTask should be used for writing out data that's 
either
+         |partitioned or bucketed. In this case neither is true.
+         |WriteJobDescription: ${desc}
+       """.stripMargin)
+
+    // currentWriter is initialized whenever we see a new key (partitionValues 
+ BucketId)
     private var currentWriter: OutputWriter = _
 
-    private var currentPath: String = _
+    /** Trackers for computing various statistics on the data as it's being 
written out. */
+    private val statsTrackers: Seq[WriteTaskStatsTracker] =
+      desc.statsTrackers.map(_.newTaskInstance())
 
-    /** Expressions that given partition columns build a path string like: 
col1=val/col2=val/... */
-    private def partitionPathExpression: Seq[Expression] = {
+    /** Extracts the partition values out of an input row. */
+    private lazy val getPartitionValues: InternalRow => UnsafeRow = {
+      val proj = UnsafeProjection.create(desc.partitionColumns, 
desc.allColumns)
+      row => proj(row)
+    }
+
+    /** Expression that given partition columns builds a path string like: 
col1=val/col2=val/... */
+    private lazy val partitionPathExpression: Expression = Concat(
       desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         val partitionName = ScalaUDF(
           ExternalCatalogUtils.getPartitionPathString _,
           StringType,
           Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId))))
         if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), 
partitionName)
-      }
+      })
+
+    /** Evaluates the `partitionPathExpression` above on a row of 
`partitionValues` and returns
+     * the partition string. */
+    private lazy val getPartitionPath: InternalRow => String = {
+      val proj = UnsafeProjection.create(Seq(partitionPathExpression), 
desc.partitionColumns)
+      row => proj(row).getString(0)
+    }
+
+    /** Given an input row, returns the corresponding `bucketId` */
+    private lazy val getBucketId: InternalRow => Int = {
+      val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, 
desc.allColumns)
+      row => proj(row).getInt(0)
     }
 
+    /** Returns the data columns to be written given an input row */
+    private val getOutputRow = UnsafeProjection.create(desc.dataColumns, 
desc.allColumns)
+
     /**
-     * Opens a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and/or a bucket id.
      * If bucket id is specified, we will append it to the end of the file 
name, but before the
      * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param partColsAndBucketId a row consisting of partition columns and a 
bucket id for the
-     *                            current row.
-     * @param getPartitionPath a function that projects the partition values 
into a path string.
+     * @param partitionValues the partition which all tuples being written by 
this `OutputWriter`
+     *                        belong to
+     * @param bucketId the bucket which all tuples being written by this 
`OutputWriter` belong to
      * @param fileCounter the number of files that have been written in the 
past for this specific
      *                    partition. This is used to limit the max number of 
records written for a
      *                    single file. The value should start from 0.
@@ -423,36 +478,24 @@ object FileFormatWriter extends Logging {
      *                          path of this writer to it.
      */
     private def newOutputWriter(
-        partColsAndBucketId: InternalRow,
-        getPartitionPath: UnsafeProjection,
+        partitionValues: Option[InternalRow],
+        bucketId: Option[Int],
         fileCounter: Int,
         updatedPartitions: mutable.Set[String]): Unit = {
-      val partDir = if (desc.partitionColumns.isEmpty) {
-        None
-      } else {
-        Option(getPartitionPath(partColsAndBucketId).getString(0))
-      }
+
+      val partDir = partitionValues.map(getPartitionPath(_))
       partDir.foreach(updatedPartitions.add)
 
-      // If the bucketId expression is defined, the bucketId column is right 
after the partition
-      // columns.
-      val bucketId = if (desc.bucketIdExpression.isDefined) {
-        
BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
-      } else {
-        ""
-      }
+      val bucketIdStr = 
bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
 
       // This must be in a form that matches our bucketing format. See 
BucketingUtils.
-      val ext = f"$bucketId.c$fileCounter%03d" +
+      val ext = f"$bucketIdStr.c$fileCounter%03d" +
         desc.outputWriterFactory.getFileExtension(taskAttemptContext)
 
-      val customPath = partDir match {
-        case Some(dir) =>
+      val customPath = partDir.flatMap { dir =>
           
desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
-        case _ =>
-          None
       }
-      currentPath = if (customPath.isDefined) {
+      val currentPath = if (customPath.isDefined) {
         committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, 
ext)
       } else {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
@@ -462,78 +505,66 @@ object FileFormatWriter extends Logging {
         path = currentPath,
         dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
+
+      statsTrackers.foreach(_.newFile(currentPath))
     }
 
     override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
-      val getPartitionColsAndBucketId = UnsafeProjection.create(
-        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
-
-      // Generates the partition path given the row generated by 
`getPartitionColsAndBucketId`.
-      val getPartPath = UnsafeProjection.create(
-        Seq(Concat(partitionPathExpression)), desc.partitionColumns)
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(desc.dataColumns, 
desc.allColumns)
-
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var totalFileCounter = 0
-      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
+      var currentPartionValues: Option[UnsafeRow] = None
+      var currentBucketId: Option[Int] = None
 
       for (row <- iter) {
-        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
-        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
-          if (currentPartColsAndBucketId != null) {
-            totalFileCounter += (fileCounter + 1)
-          }
+        val nextPartitionValues = if (isPartitioned) 
Some(getPartitionValues(row)) else None
+        val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
 
+        if (currentPartionValues != nextPartitionValues || currentBucketId != 
nextBucketId) {
           // See a new partition or bucket - write to a new partition dir (or 
a new bucket file).
-          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
-          logDebug(s"Writing partition: $currentPartColsAndBucketId")
+          if (isPartitioned && currentPartionValues != nextPartitionValues) {
+            currentPartionValues = Some(nextPartitionValues.get.copy())
+            statsTrackers.foreach(_.newPartition(currentPartionValues.get))
+          }
+          if (isBucketed) {
+            currentBucketId = nextBucketId
+            statsTrackers.foreach(_.newBucket(currentBucketId.get))
+          }
 
-          numOutputRows += recordsInFile
           recordsInFile = 0
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentPartColsAndBucketId, getPartPath, 
fileCounter, updatedPartitions)
+          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, 
updatedPartitions)
         } else if (desc.maxRecordsPerFile > 0 &&
             recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
-
-          numOutputRows += recordsInFile
           recordsInFile = 0
           fileCounter += 1
           assert(fileCounter < MAX_FILE_COUNTER,
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentPartColsAndBucketId, getPartPath, 
fileCounter, updatedPartitions)
+          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, 
updatedPartitions)
         }
-        currentWriter.write(getOutputRow(row))
+        val outputRow = getOutputRow(row)
+        currentWriter.write(outputRow)
+        statsTrackers.foreach(_.newRow(outputRow))
         recordsInFile += 1
       }
-      if (currentPartColsAndBucketId != null) {
-        totalFileCounter += (fileCounter + 1)
-      }
       releaseResources()
-      numOutputRows += recordsInFile
 
       ExecutedWriteSummary(
         updatedPartitions = updatedPartitions.toSet,
-        numOutputFile = totalFileCounter,
-        numOutputBytes = numOutputBytes,
-        numOutputRows = numOutputRows)
+        stats = statsTrackers.map(_.getFinalStats()))
     }
 
     override def releaseResources(): Unit = {
       if (currentWriter != null) {
         try {
           currentWriter.close()
-          numOutputBytes += getFileSize(taskAttemptContext.getConfiguration, 
currentPath)
         } finally {
           currentWriter = null
         }
@@ -547,12 +578,8 @@ object FileFormatWriter extends Logging {
  *
  * @param updatedPartitions the partitions updated during writing data out. 
Only valid
  *                          for dynamic partition.
- * @param numOutputFile the total number of files.
- * @param numOutputRows the number of output rows.
- * @param numOutputBytes the bytes of output data.
+ * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` 
that the job had.
  */
 case class ExecutedWriteSummary(
   updatedPartitions: Set[String],
-  numOutputFile: Int,
-  numOutputRows: Long,
-  numOutputBytes: Long)
+  stats: Seq[WriteTaskStats])

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 9ebe1e4..64e5a57 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -120,18 +120,10 @@ case class InsertIntoHadoopFsRelationCommand(
 
     if (doInsertion) {
 
-      // Callback for updating metric and metastore partition metadata
-      // after the insertion job completes.
-      def refreshCallback(summary: Seq[ExecutedWriteSummary]): Unit = {
-        val updatedPartitions = summary.flatMap(_.updatedPartitions)
-          .distinct.map(PartitioningUtils.parsePathFragment)
-
-        // Updating metrics.
-        updateWritingMetrics(summary)
-
-        // Updating metastore partition metadata.
+      def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = 
{
+        val updatedPartitions = 
updatedPartitionPaths.map(PartitioningUtils.parsePathFragment)
         if (partitionsTrackedByCatalog) {
-          val newPartitions = updatedPartitions.toSet -- 
initialMatchingPartitions
+          val newPartitions = updatedPartitions -- initialMatchingPartitions
           if (newPartitions.nonEmpty) {
             AlterTableAddPartitionCommand(
               catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, 
None)),
@@ -149,18 +141,23 @@ case class InsertIntoHadoopFsRelationCommand(
         }
       }
 
-      FileFormatWriter.write(
-        sparkSession = sparkSession,
-        plan = children.head,
-        fileFormat = fileFormat,
-        committer = committer,
-        outputSpec = FileFormatWriter.OutputSpec(
-          qualifiedOutputPath.toString, customPartitionLocations),
-        hadoopConf = hadoopConf,
-        partitionColumns = partitionColumns,
-        bucketSpec = bucketSpec,
-        refreshFunction = refreshCallback,
-        options = options)
+      val updatedPartitionPaths =
+        FileFormatWriter.write(
+          sparkSession = sparkSession,
+          plan = children.head,
+          fileFormat = fileFormat,
+          committer = committer,
+          outputSpec = FileFormatWriter.OutputSpec(
+            qualifiedOutputPath.toString, customPartitionLocations),
+          hadoopConf = hadoopConf,
+          partitionColumns = partitionColumns,
+          bucketSpec = bucketSpec,
+          statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
+          options = options)
+
+
+      // update metastore partition metadata
+      refreshUpdatedPartitions(updatedPartitionPaths)
 
       // refresh cached files in FileIndex
       fileIndex.foreach(_.refresh())

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
new file mode 100644
index 0000000..c39a82e
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+
+/**
+ * To be implemented by classes that represent data statistics collected 
during a Write Task.
+ * It is important that instances of this type are [[Serializable]], as they 
will be gathered
+ * on the driver from all executors.
+ */
+trait WriteTaskStats extends Serializable
+
+
+/**
+ * A trait for classes that are capable of collecting statistics on data 
that's being processed by
+ * a single write task in [[FileFormatWriter]] - i.e. there should be one 
instance per write task.
+ *
+ * This trait is coupled with the way [[FileFormatWriter]] works, in the sense 
that its methods
+ * will be called according to how tuples are being written out to disk, 
namely in sorted order
+ * according to partitionValue(s), then bucketId.
+ *
+ * As such, a typical call scenario is:
+ *
+ * newPartition -> newBucket -> newFile -> newRow -.
+ *    ^        |______^___________^ ^         ^____|
+ *    |               |             |______________|
+ *    |               |____________________________|
+ *    |____________________________________________|
+ *
+ * newPartition and newBucket events are only triggered if the relation to be 
written out is
+ * partitioned and/or bucketed, respectively.
+ */
+trait WriteTaskStatsTracker {
+
+  /**
+   * Process the fact that a new partition is about to be written.
+   * Only triggered when the relation is partitioned by a (non-empty) sequence 
of columns.
+   * @param partitionValues The values that define this new partition.
+   */
+  def newPartition(partitionValues: InternalRow): Unit
+
+  /**
+   * Process the fact that a new bucket is about to be written.
+   * Only triggered when the relation is bucketed by a (non-empty) sequence of 
columns.
+   * @param bucketId The bucket number.
+   */
+  def newBucket(bucketId: Int): Unit
+
+  /**
+   * Process the fact that a new file is about to be written.
+   * @param filePath Path of the file into which future rows will be written.
+   */
+  def newFile(filePath: String): Unit
+
+  /**
+   * Process the fact that a new row is being written; update the tracked 
statistics accordingly.
+   * The row will be written to the most recently witnessed file (via 
`newFile`).
+   * @note Keep in mind that any overhead here is per-row, obviously,
+   *       so implementations should be as lightweight as possible.
+   * @param row Current data row to be processed.
+   */
+  def newRow(row: InternalRow): Unit
+
+  /**
+   * Returns the final statistics computed so far.
+   * @note This may only be called once. Further use of the object may lead to 
undefined behavior.
+   * @return An object of subtype of [[WriteTaskStats]], to be sent to the 
driver.
+   */
+  def getFinalStats(): WriteTaskStats
+}
+
+
+/**
+ * A class implementing this trait is basically a collection of parameters 
that are necessary
+ * for instantiating a (derived type of) [[WriteTaskStatsTracker]] on all 
executors and then
+ * processing the statistics produced by them (e.g. saving them to 
memory/disk, issuing warnings, etc).
+ * It is therefore important that such an object is [[Serializable]], as it 
will be sent
+ * from the driver to all executors.
+ */
+trait WriteJobStatsTracker extends Serializable {
+
+  /**
+   * Instantiates a [[WriteTaskStatsTracker]], based on (non-transient) 
members of this class.
+   * To be called by executors.
+   * @return A [[WriteTaskStatsTracker]] instance to be used for computing 
stats during a write task
+   */
+  def newTaskInstance(): WriteTaskStatsTracker
+
+  /**
+   * Process the given collection of stats computed during this job.
+   * E.g. aggregate them, write them to memory / disk, issue warnings, 
whatever.
+   * @param stats One [[WriteTaskStats]] object from each successful write 
task.
+   * @note The type of @param `stats` is too generic. These classes should 
probably be parametrized:
+   *   WriteTaskStatsTracker[S <: WriteTaskStats]
+   *   WriteJobStatsTracker[S <: WriteTaskStats, T <: WriteTaskStatsTracker[S]]
+   * and this would then be:
+   *   def processStats(stats: Seq[S]): Unit
+   * but then we wouldn't be able to have a Seq[WriteJobStatsTracker] due to 
type
+   * co-/contra-variance considerations. Instead, you may feel free to just 
cast `stats`
+   * to the expected derived type when implementing this method in a derived 
class.
+   * The framework will make sure to call this with the right arguments.
+   */
+  def processStats(stats: Seq[WriteTaskStats]): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 0ed2dbe..72e5ac4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -128,7 +128,7 @@ class FileStreamSink(
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = None,
-        refreshFunction = _ => (),
+        statsTrackers = Nil,
         options = options)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/95ad960c/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index b9461ad..b6f4898 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -31,7 +31,6 @@ import org.apache.hadoop.hive.ql.exec.TaskRunner
 import org.apache.hadoop.hive.ql.ErrorMsg
 import org.apache.hadoop.hive.ql.plan.TableDesc
 
-import org.apache.spark.SparkContext
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
@@ -40,7 +39,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.{CommandUtils, 
DataWritingCommand}
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
 import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion}
@@ -355,7 +353,7 @@ case class InsertIntoHiveTable(
       hadoopConf = hadoopConf,
       partitionColumns = partitionAttributes,
       bucketSpec = None,
-      refreshFunction = updateWritingMetrics,
+      statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
       options = Map.empty)
 
     if (partition.nonEmpty) {

