Repository: spark
Updated Branches:
  refs/heads/master 695f0e919 -> bc3760d40


[SPARK-14237][SQL] De-duplicate partition value appending logic in various 
buildReader() implementations

## What changes were proposed in this pull request?

Currently, various `FileFormat` data sources duplicate essentially the same code 
for appending partition values. This PR eliminates that duplication.

A new method `buildReaderWithPartitionValues()` is added to `FileFormat` with a 
default implementation that appends partition values to `InternalRow`s produced 
by the reader function returned by `buildReader()`.
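
For readers skimming the diff, the shared logic boils down to roughly the following 
pattern (a simplified sketch; the helper name `appendPartitionValues` is illustrative 
only, the real code is the default `buildReaderWithPartitionValues()` added in 
`fileSourceInterfaces.scala` below):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.StructType

// Illustrative helper: joins partition values onto every data row and converts
// the result to `UnsafeRow` with a single generated projection.
def appendPartitionValues(
    requiredSchema: StructType,
    partitionSchema: StructType,
    partitionValues: InternalRow,
    rows: Iterator[InternalRow]): Iterator[InternalRow] = {
  // Output columns are the required data columns followed by the partition columns.
  val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
  val joinedRow = new JoinedRow()
  // The generated projection both appends the partition columns and performs
  // the safe-to-unsafe conversion.
  val appendColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
  rows.map(dataRow => appendColumns(joinedRow(dataRow, partitionValues)))
}
```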

Data sources that don't need the default behavior override 
`buildReaderWithPartitionValues()` and simply delegate to `buildReader()`: Parquet 
already appends partition values inside `buildReader()` because of its vectorized 
reader, and the Text data source doesn't support partitioning at all.

This PR brings two benefits:

1. Most obviously, it de-duplicates the partition value appending logic.

2. The reader function returned by `buildReader()` now only needs to produce 
`InternalRow`s rather than `UnsafeRow`s, provided the data source doesn't override 
`buildReaderWithPartitionValues()`.

   This works because the safe-to-unsafe conversion is performed while appending 
partition values. It makes 3rd-party data sources (e.g. spark-avro) easier to 
implement, since they no longer need to access private APIs involving 
`UnsafeRow`. A hypothetical reader function of this kind is sketched after this list.
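
As a hypothetical illustration (not part of this commit), a third-party format's 
`buildReader()` could now return a reader function that emits plain 
`GenericInternalRow`s and rely on the default `buildReaderWithPartitionValues()` to 
append partition values and perform the unsafe conversion:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical reader function for a single StringType column; reads from the
// local file system only, purely for illustration.
val readFunction: PartitionedFile => Iterator[InternalRow] = (file: PartitionedFile) => {
  scala.io.Source.fromFile(file.filePath).getLines().map { line =>
    // Emit a "safe" row; no `UnsafeRow` APIs are needed here, since the default
    // `buildReaderWithPartitionValues()` performs the conversion.
    new GenericInternalRow(Array[Any](UTF8String.fromString(line)))
  }
}
```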

## How was this patch tested?

Existing tests should cover this change.

Author: Cheng Lian <l...@databricks.com>

Closes #12866 from liancheng/spark-14237-simplify-partition-values-appending.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc3760d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc3760d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc3760d4

Branch: refs/heads/master
Commit: bc3760d405cc8c3ffcd957b188afa8b7e3b1f824
Parents: 695f0e9
Author: Cheng Lian <l...@databricks.com>
Authored: Wed May 4 14:16:57 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed May 4 14:16:57 2016 +0800

----------------------------------------------------------------------
 .../spark/ml/source/libsvm/LibSVMRelation.scala | 17 +--------
 .../datasources/FileSourceStrategy.scala        |  2 +-
 .../datasources/csv/DefaultSource.scala         | 17 ++-------
 .../datasources/fileSourceInterfaces.scala      | 40 ++++++++++++++++++++
 .../datasources/json/JSONRelation.scala         | 10 +----
 .../datasources/parquet/ParquetRelation.scala   | 14 +++++++
 .../datasources/text/DefaultSource.scala        | 13 +++++++
 .../execution/datasources/csv/CSVSuite.scala    |  3 --
 .../apache/spark/sql/hive/orc/OrcRelation.scala | 11 +-----
 9 files changed, 74 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index ba2e1e2..5f78fab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -204,25 +204,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
 
       val converter = RowEncoder(dataSchema)
 
-      val unsafeRowIterator = points.map { pt =>
+      points.map { pt =>
         val features = if (sparse) pt.features.toSparse else pt.features.toDense
         converter.toRow(Row(pt.label, features))
       }
-
-      def toAttribute(f: StructField): AttributeReference =
-        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-
-      // Appends partition values
-      val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
-      val requiredOutput = fullOutput.filter { a =>
-        requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
-      }
-      val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
-      unsafeRowIterator.map { dataRow =>
-        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
-      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 615906a..8a93c6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -106,7 +106,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
       logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
 
-      val readFile = files.fileFormat.buildReader(
+      val readFile = files.fileFormat.buildReaderWithPartitionValues(
         sparkSession = files.sparkSession,
         dataSchema = files.dataSchema,
         partitionSchema = files.partitionSchema,

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 75143e6..948fac0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -117,20 +117,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
 
       CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
 
-      val unsafeRowIterator = {
-        val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
-        val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
-        tokenizedIterator.flatMap(parser(_).toSeq)
-      }
-
-      // Appends partition values
-      val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
-      val joinedRow = new JoinedRow()
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
-
-      unsafeRowIterator.map { dataRow =>
-        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
-      }
+      val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+      val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
+      tokenizedIterator.flatMap(parser(_).toSeq)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 0a34611..24e2bf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.FileRelation
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.{StringType, StructType}
@@ -239,6 +240,45 @@ trait FileFormat {
   }
 
   /**
+   * Exactly the same as [[buildReader]] except that the reader function returned by this method
+   * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]]
+   * returns.
+   */
+  private[sql] def buildReaderWithPartitionValues(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String],
+      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+    val dataReader = buildReader(
+      sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+
+    new (PartitionedFile => Iterator[InternalRow]) with Serializable {
+      private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+
+      private val joinedRow = new JoinedRow()
+
+      // Using lazy val to avoid serialization
+      private lazy val appendPartitionColumns =
+        GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+
+      override def apply(file: PartitionedFile): Iterator[InternalRow] = {
+        // Using local val to avoid per-row lazy val check (pre-mature optimization?...)
+        val converter = appendPartitionColumns
+
+        // Note that we have to apply the converter even though `file.partitionValues` is empty.
+        // This is because the converter is also responsible for converting safe `InternalRow`s into
+        // `UnsafeRow`s.
+        dataReader(file).map { dataRow =>
+          converter(joinedRow(dataRow, file.partitionValues))
+        }
+      }
+    }
+  }
+
+  /**
    * Returns a [[OutputWriterFactory]] for generating output writers that can write data.
    * This method is current used only by FileStreamSinkWriter to generate output writers that
    * does not use output committers to write data. The OutputWriter generated by the returned

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 6244658..4c97abe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -106,22 +106,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
       .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
-    val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
-    val joinedRow = new JoinedRow()
-
     (file: PartitionedFile) => {
       val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString)
 
-      val rows = JacksonParser.parseJson(
+      JacksonParser.parseJson(
         lines,
         requiredSchema,
         columnNameOfCorruptRecord,
         parsedOptions)
-
-      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
-      rows.map { row =>
-        appendPartitionColumns(joinedRow(row, file.partitionValues))
-      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 79185df..cf5c8e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -255,6 +255,20 @@ private[sql] class DefaultSource
       schema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
+  override private[sql] def buildReaderWithPartitionValues(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String],
+      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+    // For Parquet data source, `buildReader` already handles partition values appending. Here we
+    // simply delegate to `buildReader`.
+    buildReader(
+      sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+  }
+
   override def buildReader(
       sparkSession: SparkSession,
       dataSchema: StructType,

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 348edfc..f22c024 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -83,6 +83,19 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     }
   }
 
+  override private[sql] def buildReaderWithPartitionValues(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String],
+      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+    // Text data source doesn't support partitioning. Here we simply delegate to `buildReader`.
+    buildReader(
+      sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+  }
+
   override def buildReader(
       sparkSession: SparkSession,
       dataSchema: StructType,

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 07f00a0..28e5905 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -22,9 +22,6 @@ import java.nio.charset.UnsupportedCharsetException
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.GzipCodec
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3760d4/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index d6a847f..89d258e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -157,20 +157,11 @@ private[sql] class DefaultSource
         }
 
         // Unwraps `OrcStruct`s to `UnsafeRow`s
-        val unsafeRowIterator = OrcRelation.unwrapOrcStructs(
+        OrcRelation.unwrapOrcStructs(
           conf,
           requiredSchema,
           Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
           new RecordReaderIterator[OrcStruct](orcRecordReader))
-
-        // Appends partition values
-        val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
-        val joinedRow = new JoinedRow()
-        val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
-
-        unsafeRowIterator.map { dataRow =>
-          appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
-        }
       }
     }
   }

