Repository: spark
Updated Branches:
  refs/heads/master 20a89478e -> 8d05a7a98


[SPARK-10216][SQL] Avoid creating empty files during overwriting with group by 
query

## What changes were proposed in this pull request?

Currently, an `INSERT INTO` with a `GROUP BY` query creates one output file per
shuffle partition (200 by default, set by `spark.sql.shuffle.partitions`), even
for partitions that contain no rows, which results in lots of empty files.

This PR avoids creating empty files when overwriting a Hive table or an
internal data source with a `GROUP BY` query.

Each write task now checks whether its partition contains any rows, and
creates/writes a file only when it actually has data.

## How was this patch tested?

Unit tests in `InsertIntoHiveTableSuite` and `HadoopFsRelationTest`.

Closes #8411

Author: hyukjinkwon <gurwls...@gmail.com>
Author: Keuntae Park <sir...@apache.org>

Closes #12855 from HyukjinKwon/pr/8411.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d05a7a9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d05a7a9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d05a7a9

Branch: refs/heads/master
Commit: 8d05a7a98bdbd3ce7c81d273e05a375877ebe68f
Parents: 20a8947
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Tue May 17 11:18:51 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue May 17 11:18:51 2016 -0700

----------------------------------------------------------------------
 .../execution/datasources/WriterContainer.scala | 221 ++++++++++---------
 .../spark/sql/hive/hiveWriterContainers.scala   |  24 +-
 .../sql/hive/InsertIntoHiveTableSuite.scala     |  41 +++-
 .../sql/sources/HadoopFsRelationTest.scala      |  22 +-
 4 files changed, 182 insertions(+), 126 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8d05a7a9/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 3b064a5..7e12bbb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer(
   extends BaseWriterContainer(relation, job, isAppend) {
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): 
Unit = {
-    executorSideSetup(taskContext)
-    val configuration = taskAttemptContext.getConfiguration
-    configuration.set("spark.sql.sources.output.path", outputPath)
-    var writer = newOutputWriter(getWorkPath)
-    writer.initConverter(dataSchema)
-
-    // If anything below fails, we should abort the task.
-    try {
-      Utils.tryWithSafeFinallyAndFailureCallbacks {
-        while (iterator.hasNext) {
-          val internalRow = iterator.next()
-          writer.writeInternal(internalRow)
-        }
-        commitTask()
-      }(catchBlock = abortTask())
-    } catch {
-      case t: Throwable =>
-        throw new SparkException("Task failed while writing rows", t)
-    }
+    if (iterator.hasNext) {
+      executorSideSetup(taskContext)
+      val configuration = taskAttemptContext.getConfiguration
+      configuration.set("spark.sql.sources.output.path", outputPath)
+      var writer = newOutputWriter(getWorkPath)
+      writer.initConverter(dataSchema)
 
-    def commitTask(): Unit = {
+      // If anything below fails, we should abort the task.
       try {
-        if (writer != null) {
-          writer.close()
-          writer = null
-        }
-        super.commitTask()
+        Utils.tryWithSafeFinallyAndFailureCallbacks {
+          while (iterator.hasNext) {
+            val internalRow = iterator.next()
+            writer.writeInternal(internalRow)
+          }
+          commitTask()
+        }(catchBlock = abortTask())
       } catch {
-        case cause: Throwable =>
-          // This exception will be handled in 
`InsertIntoHadoopFsRelation.insert$writeRows`, and
-          // will cause `abortTask()` to be invoked.
-          throw new RuntimeException("Failed to commit task", cause)
+        case t: Throwable =>
+          throw new SparkException("Task failed while writing rows", t)
       }
-    }
 
-    def abortTask(): Unit = {
-      try {
-        if (writer != null) {
-          writer.close()
+      def commitTask(): Unit = {
+        try {
+          if (writer != null) {
+            writer.close()
+            writer = null
+          }
+          super.commitTask()
+        } catch {
+          case cause: Throwable =>
+            // This exception will be handled in 
`InsertIntoHadoopFsRelation.insert$writeRows`, and
+            // will cause `abortTask()` to be invoked.
+            throw new RuntimeException("Failed to commit task", cause)
+        }
+      }
+
+      def abortTask(): Unit = {
+        try {
+          if (writer != null) {
+            writer.close()
+          }
+        } finally {
+          super.abortTask()
         }
-      } finally {
-        super.abortTask()
       }
     }
   }
@@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): 
Unit = {
-    executorSideSetup(taskContext)
-
-    // We should first sort by partition columns, then bucket id, and finally 
sorting columns.
-    val sortingExpressions: Seq[Expression] = partitionColumns ++ 
bucketIdExpression ++ sortColumns
-    val getSortingKey = UnsafeProjection.create(sortingExpressions, 
inputSchema)
-
-    val sortingKeySchema = StructType(sortingExpressions.map {
-      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-      // The sorting expressions are all `Attribute` except bucket id.
-      case _ => StructField("bucketId", IntegerType, nullable = false)
-    })
-
-    // Returns the data columns to be written given an input row
-    val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-
-    // Returns the partition path given a partition key.
-    val getPartitionString =
-      UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, 
partitionColumns)
-
-    // Sorts the data before write, so that we only need one writer at the 
same time.
-    // TODO: inject a local sort operator in planning.
-    val sorter = new UnsafeKVExternalSorter(
-      sortingKeySchema,
-      StructType.fromAttributes(dataColumns),
-      SparkEnv.get.blockManager,
-      SparkEnv.get.serializerManager,
-      TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) 
{
-      identity
-    } else {
-      
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map
 {
-        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, 
expr.nullable)
+    if (iterator.hasNext) {
+      executorSideSetup(taskContext)
+
+      // We should first sort by partition columns, then bucket id, and 
finally sorting columns.
+      val sortingExpressions: Seq[Expression] =
+        partitionColumns ++ bucketIdExpression ++ sortColumns
+      val getSortingKey = UnsafeProjection.create(sortingExpressions, 
inputSchema)
+
+      val sortingKeySchema = StructType(sortingExpressions.map {
+        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+        // The sorting expressions are all `Attribute` except bucket id.
+        case _ => StructField("bucketId", IntegerType, nullable = false)
       })
-    }
 
-    val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+      // Returns the partition path given a partition key.
+      val getPartitionString =
+        UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, 
partitionColumns)
+
+      // Sorts the data before write, so that we only need one writer at the 
same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        SparkEnv.get.serializerManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val getBucketingKey: InternalRow => InternalRow = if 
(sortColumns.isEmpty) {
+        identity
+      } else {
+        
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map
 {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, 
expr.nullable)
+        })
+      }
 
-    // If anything below fails, we should abort the task.
-    var currentWriter: OutputWriter = null
-    try {
-      Utils.tryWithSafeFinallyAndFailureCallbacks {
-        var currentKey: UnsafeRow = null
-        while (sortedIterator.next()) {
-          val nextKey = 
getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-              currentWriter = null
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
+      val sortedIterator = sorter.sortedIterator()
 
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      // If anything below fails, we should abort the task.
+      var currentWriter: OutputWriter = null
+      try {
+        Utils.tryWithSafeFinallyAndFailureCallbacks {
+          var currentKey: UnsafeRow = null
+          while (sortedIterator.next()) {
+            val nextKey = 
getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+            if (currentKey != nextKey) {
+              if (currentWriter != null) {
+                currentWriter.close()
+                currentWriter = null
+              }
+              currentKey = nextKey.copy()
+              logDebug(s"Writing partition: $currentKey")
+
+              currentWriter = newOutputWriter(currentKey, getPartitionString)
+            }
+            currentWriter.writeInternal(sortedIterator.getValue)
+          }
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
           }
-          currentWriter.writeInternal(sortedIterator.getValue)
-        }
-        if (currentWriter != null) {
-          currentWriter.close()
-          currentWriter = null
-        }
 
-        commitTask()
-      }(catchBlock = {
-        if (currentWriter != null) {
-          currentWriter.close()
-        }
-        abortTask()
-      })
-    } catch {
-      case t: Throwable =>
-        throw new SparkException("Task failed while writing rows", t)
+          commitTask()
+        }(catchBlock = {
+          if (currentWriter != null) {
+            currentWriter.close()
+          }
+          abortTask()
+        })
+      } catch {
+        case t: Throwable =>
+          throw new SparkException("Task failed while writing rows", t)
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8d05a7a9/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 794fe26..706fdbc 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer(
 
   // this function is executed on executor side
   def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit 
= {
-    val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = 
prepareForWrite()
-    executorSideSetup(context.stageId, context.partitionId, 
context.attemptNumber)
-
-    iterator.foreach { row =>
-      var i = 0
-      while (i < fieldOIs.length) {
-        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, 
dataTypes(i)))
-        i += 1
+    if (iterator.hasNext) {
+      val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) 
= prepareForWrite()
+      executorSideSetup(context.stageId, context.partitionId, 
context.attemptNumber)
+
+      iterator.foreach { row =>
+        var i = 0
+        while (i < fieldOIs.length) {
+          outputData(i) = if (row.isNullAt(i)) null else 
wrappers(i)(row.get(i, dataTypes(i)))
+          i += 1
+        }
+        writer.write(serializer.serialize(outputData, standardOI))
       }
-      writer.write(serializer.serialize(outputData, standardOI))
-    }
 
-    close()
+      close()
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8d05a7a9/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 82d3e49..883cdac 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
 
 import java.io.File
 
-import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{QueryTest, _}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with 
TestHiveSingleton with Bef
 
     sql(
       s"""
-         |CREATE TABLE table_with_partition(c1 string)
-         |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
-         |location '${tmpDir.toURI.toString}'
-        """.stripMargin)
+        |CREATE TABLE table_with_partition(c1 string)
+        |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
+        |location '${tmpDir.toURI.toString}'
+      """.stripMargin)
     sql(
       """
         |INSERT OVERWRITE TABLE table_with_partition
@@ -216,6 +216,35 @@ class InsertIntoHiveTableSuite extends QueryTest with 
TestHiveSingleton with Bef
     sql("DROP TABLE hiveTableWithStructValue")
   }
 
+  test("SPARK-10216: Avoid empty files during overwrite into Hive table with 
group by query") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val testDataset = hiveContext.sparkContext.parallelize(
+        (1 to 2).map(i => TestData(i, i.toString))).toDF()
+      testDataset.createOrReplaceTempView("testDataset")
+
+      val tmpDir = Utils.createTempDir()
+      sql(
+        s"""
+          |CREATE TABLE table1(key int,value string)
+          |location '${tmpDir.toURI.toString}'
+        """.stripMargin)
+      sql(
+        """
+          |INSERT OVERWRITE TABLE table1
+          |SELECT count(key), value FROM testDataset GROUP BY value
+        """.stripMargin)
+
+      val overwrittenFiles = tmpDir.listFiles()
+        .filter(f => f.isFile && !f.getName.endsWith(".crc"))
+        .sortBy(_.getName)
+      val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
+
+      assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+
+      sql("DROP TABLE table1")
+    }
+  }
+
   test("Reject partitioning that does not match table") {
     withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
       sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY 
(part string)")

http://git-wip-us.apache.org/repos/asf/spark/blob/8d05a7a9/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index f4d6333..78d2dc2 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FileScanRDD, 
HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, 
LocalityTestFileSystem}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with 
SQLTestUtils with Tes
       }
     }
   }
+
+  test("SPARK-10216: Avoid empty files during overwriting with group by 
query") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      withTempPath { path =>
+        val df = spark.range(0, 5)
+        val groupedDF = df.groupBy("id").count()
+        groupedDF.write
+          .format(dataSourceName)
+          .mode(SaveMode.Overwrite)
+          .save(path.getCanonicalPath)
+
+        val overwrittenFiles = path.listFiles()
+          .filter(f => f.isFile && !f.getName.startsWith(".") && 
!f.getName.startsWith("_"))
+          .sortBy(_.getName)
+        val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 
0)
+
+        assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+      }
+    }
+  }
 }
 
 // This class is used to test SPARK-8578. We should not use any custom output 
committer when


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to