Repository: spark
Updated Branches:
  refs/heads/master a00181418 -> 39872af88


[SPARK-25684][SQL] Organize header related codes in CSV datasource

## What changes were proposed in this pull request?

1. Move `CSVDataSource.makeSafeHeader` to `CSVUtils.makeSafeHeader` (as is).

    - Historically, in the first round of refactoring (which I did), I intended 
to collect all CSV-specific handling (like options), filtering, header 
extraction, etc. in one place.

    - See `JsonDataSource`. `CSVDataSource` is now quite consistent with 
`JsonDataSource`. Since CSV's code path is quite complicated, we had better 
keep the two aligned as much as possible.

2. Create `CSVHeaderChecker` and move the `enforceSchema` logic into it (a 
minimal usage sketch follows this list).

    - Header checking and column pruning were added (per 
https://github.com/apache/spark/pull/20894 and 
https://github.com/apache/spark/pull/21296), but some of the code, such as 
https://github.com/apache/spark/pull/22123, is duplicated.

    - Also, the header-checking code was scattered here and there, which was 
quite error-prone; we had better put it in a single place. See 
https://github.com/apache/spark/pull/22656.

3. Move `CSVDataSource.checkHeaderColumnNames` to 
`CSVHeaderChecker.checkHeaderColumnNames` (as is).

    - Similar reasons as in 1 above.
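
To make the new shape concrete, here is a minimal sketch (not code from this 
patch) of how the relocated pieces fit together, mirroring the new per-file 
call site in `CSVFileFormat.buildReader`. The object name 
`CSVHeaderCheckerExample` and the helper `buildHeaderChecker` are hypothetical, 
and the sketch assumes it sits in the `org.apache.spark.sql.execution.datasources.csv` 
package so the package-level classes are visible.

```scala
package org.apache.spark.sql.execution.datasources.csv

import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.types.StructType

// Hypothetical helper (not part of the patch) mirroring the new per-file call
// site in CSVFileFormat.buildReader: pick the schema the parser will actually
// produce, then build one header checker per partitioned file.
object CSVHeaderCheckerExample {
  def buildHeaderChecker(
      dataSchema: StructType,
      requiredSchema: StructType,
      parsedOptions: CSVOptions,
      file: PartitionedFile,
      columnPruning: Boolean): CSVHeaderChecker = {
    // With column pruning enabled the parser only produces requiredSchema, so
    // the header is checked against that; otherwise check the full dataSchema.
    val schema = if (columnPruning) requiredSchema else dataSchema
    new CSVHeaderChecker(
      schema,
      parsedOptions,
      source = s"CSV file: ${file.filePath}",
      // Only the partition that starts the file can contain the header.
      isStartOfFile = file.start == 0)
  }
}
```

The `Dataset[String]` path in `DataFrameReader.csv` uses the same class but 
calls `headerChecker.checkHeaderColumnNames(firstLine)` directly, since there 
is no file offset to consult there.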

## How was this patch tested?

Existing tests should cover this.

Closes #22676 from HyukjinKwon/refactoring-csv.

Authored-by: hyukjinkwon <gurwls...@apache.org>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/39872af8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/39872af8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/39872af8

Branch: refs/heads/master
Commit: 39872af882e3d73667acfab93c9de962c9c8939d
Parents: a001814
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Fri Oct 12 09:16:41 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Fri Oct 12 09:16:41 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameReader.scala  |  18 +--
 .../datasources/csv/CSVDataSource.scala         | 161 ++-----------------
 .../datasources/csv/CSVFileFormat.scala         |  11 +-
 .../datasources/csv/CSVHeaderChecker.scala      | 131 +++++++++++++++
 .../execution/datasources/csv/CSVUtils.scala    |  44 ++++-
 .../datasources/csv/UnivocityParser.scala       |  34 ++--
 6 files changed, 217 insertions(+), 182 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7269446..3af70b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -505,20 +505,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
-    val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) {
-      val firstLine = maybeFirstLine.get
-      val parser = new CsvParser(parsedOptions.asParserSettings)
-      val columnNames = parser.parseLine(firstLine)
-      CSVDataSource.checkHeaderColumnNames(
+    val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      val headerChecker = new CSVHeaderChecker(
         actualSchema,
-        columnNames,
-        csvDataset.getClass.getCanonicalName,
-        parsedOptions.enforceSchema,
-        sparkSession.sessionState.conf.caseSensitiveAnalysis)
+        parsedOptions,
+        source = s"CSV source: $csvDataset")
+      headerChecker.checkHeaderColumnNames(firstLine)
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
-    } else {
-      filteredLines.rdd
-    }
+    }.getOrElse(filteredLines.rdd)
 
     val parsed = linesWithoutHeader.mapPartitions { iter =>
       val rawParser = new UnivocityParser(actualSchema, parsedOptions)

http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b93f418..0b5a719 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -51,11 +51,8 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      // Actual schema of data in the csv file
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow]
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -75,48 +72,6 @@ abstract class CSVDataSource extends Serializable {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType
-
-  /**
-   * Generates a header from the given row which is null-safe and duplicate-safe.
-   */
-  protected def makeSafeHeader(
-      row: Array[String],
-      caseSensitive: Boolean,
-      options: CSVOptions): Array[String] = {
-    if (options.headerFlag) {
-      val duplicates = {
-        val headerNames = row.filter(_ != null)
-          // scalastyle:off caselocale
-          .map(name => if (caseSensitive) name else name.toLowerCase)
-          // scalastyle:on caselocale
-        headerNames.diff(headerNames.distinct).distinct
-      }
-
-      row.zipWithIndex.map { case (value, index) =>
-        if (value == null || value.isEmpty || value == options.nullValue) {
-          // When there are empty strings or the values set in `nullValue`, put the
-          // index as the suffix.
-          s"_c$index"
-          // scalastyle:off caselocale
-        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
-          // scalastyle:on caselocale
-          // When there are case-insensitive duplicates, put the index as the suffix.
-          s"$value$index"
-        } else if (duplicates.contains(value)) {
-          // When there are duplicates, put the index as the suffix.
-          s"$value$index"
-        } else {
-          value
-        }
-      }
-    } else {
-      row.zipWithIndex.map { case (_, index) =>
-        // Uses default column names, "_c#" where # is its position of fields
-        // when header option is disabled.
-        s"_c$index"
-      }
-    }
-  }
 }
 
 object CSVDataSource extends Logging {
@@ -127,67 +82,6 @@ object CSVDataSource extends Logging {
       TextInputCSVDataSource
     }
   }
-
-  /**
-   * Checks that column names in a CSV header and field names in the schema are the same
-   * by taking into account case sensitivity.
-   *
-   * @param schema - provided (or inferred) schema to which CSV must conform.
-   * @param columnNames - names of CSV columns that must be checked against to the schema.
-   * @param fileName - name of CSV file that are currently checked. It is used in error messages.
-   * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column
-   *                        names are checked for conformance to the schema. In the case if
-   *                        the column name don't conform to the schema, an exception is thrown.
-   * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
-   *                        names is not case sensitive.
-   */
-  def checkHeaderColumnNames(
-      schema: StructType,
-      columnNames: Array[String],
-      fileName: String,
-      enforceSchema: Boolean,
-      caseSensitive: Boolean): Unit = {
-    if (columnNames != null) {
-      val fieldNames = schema.map(_.name).toIndexedSeq
-      val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
-      var errorMessage: Option[String] = None
-
-      if (headerLen == schemaSize) {
-        var i = 0
-        while (errorMessage.isEmpty && i < headerLen) {
-          var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
-          if (!caseSensitive) {
-            // scalastyle:off caselocale
-            nameInSchema = nameInSchema.toLowerCase
-            nameInHeader = nameInHeader.toLowerCase
-            // scalastyle:on caselocale
-          }
-          if (nameInHeader != nameInSchema) {
-            errorMessage = Some(
-              s"""|CSV header does not conform to the schema.
-                  | Header: ${columnNames.mkString(", ")}
-                  | Schema: ${fieldNames.mkString(", ")}
-                  |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
-                  |CSV file: $fileName""".stripMargin)
-          }
-          i += 1
-        }
-      } else {
-        errorMessage = Some(
-          s"""|Number of column in CSV header is not equal to number of fields in the schema:
-              | Header length: $headerLen, schema size: $schemaSize
-              |CSV file: $fileName""".stripMargin)
-      }
-
-      errorMessage.foreach { msg =>
-        if (enforceSchema) {
-          logWarning(msg)
-        } else {
-          throw new IllegalArgumentException(msg)
-        }
-      }
-    }
-  }
 }
 
 object TextInputCSVDataSource extends CSVDataSource {
@@ -197,10 +91,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow] = {
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
@@ -209,25 +101,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       }
     }
 
-    val hasHeader = parser.options.headerFlag && file.start == 0
-    if (hasHeader) {
-      // Checking that column names in the header are matched to field names of the schema.
-      // The header will be removed from lines.
-      // Note: if there are only comments in the first block, the header would probably
-      // be not extracted.
-      CSVUtils.extractHeader(lines, parser.options).foreach { header =>
-        val schema = if (columnPruning) requiredSchema else dataSchema
-        val columnNames = parser.tokenizer.parseLine(header)
-        CSVDataSource.checkHeaderColumnNames(
-          schema,
-          columnNames,
-          file.filePath,
-          parser.options.enforceSchema,
-          caseSensitive)
-      }
-    }
-
-    UnivocityParser.parseIterator(lines, parser, requiredSchema)
+    UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema)
   }
 
   override def infer(
@@ -251,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource {
     maybeFirstLine.map(csvParser.parseLine(_)) match {
       case Some(firstRow) if firstRow != null =>
         val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
         val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
         val tokenRDD = sampled.rdd.mapPartitions { iter =>
           val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
@@ -298,26 +172,13 @@ object MultiLineCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      requiredSchema: StructType,
-      dataSchema: StructType,
-      caseSensitive: Boolean,
-      columnPruning: Boolean): Iterator[InternalRow] = {
-    def checkHeader(header: Array[String]): Unit = {
-      val schema = if (columnPruning) requiredSchema else dataSchema
-      CSVDataSource.checkHeaderColumnNames(
-        schema,
-        header,
-        file.filePath,
-        parser.options.enforceSchema,
-        caseSensitive)
-    }
-
+      headerChecker: CSVHeaderChecker,
+      requiredSchema: StructType): Iterator[InternalRow] = {
     UnivocityParser.parseStream(
      CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
-      parser.options.headerFlag,
       parser,
-      requiredSchema,
-      checkHeader)
+      headerChecker,
+      requiredSchema)
   }
 
   override def infer(
@@ -334,7 +195,7 @@ object MultiLineCSVDataSource extends CSVDataSource {
     }.take(1).headOption match {
       case Some(firstRow) =>
         val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
         val tokenRDD = csv.flatMap { lines =>
           UnivocityParser.tokenizeStream(
             CodecStreams.createInputStreamWithCloseResource(

http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 9aad0bd..3de1c2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,7 +130,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
           "df.filter($\"_corrupt_record\".isNotNull).count()."
       )
     }
-    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     val columnPruning = sparkSession.sessionState.conf.csvColumnPruning
 
     (file: PartitionedFile) => {
@@ -139,14 +138,16 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         parsedOptions)
+      val schema = if (columnPruning) requiredSchema else dataSchema
+      val isStartOfFile = file.start == 0
+      val headerChecker = new CSVHeaderChecker(
+        schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile)
       CSVDataSource(parsedOptions).readFile(
         conf,
         file,
         parser,
-        requiredSchema,
-        dataSchema,
-        caseSensitive,
-        columnPruning)
+        headerChecker,
+        requiredSchema)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
new file mode 100644
index 0000000..558ee91
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import com.univocity.parsers.csv.CsvParser
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Checks that column names in a CSV header and field names in the schema are the same
+ * by taking into account case sensitivity.
+ *
+ * @param schema provided (or inferred) schema to which CSV must conform.
+ * @param options parsed CSV options.
+ * @param source name of CSV source that are currently checked. It is used in error messages.
+ * @param isStartOfFile indicates if the currently processing partition is the start of the file.
+ *                      if unknown or not applicable (for instance when the input is a dataset),
+ *                      can be omitted.
+ */
+class CSVHeaderChecker(
+    schema: StructType,
+    options: CSVOptions,
+    source: String,
+    isStartOfFile: Boolean = false) extends Logging {
+
+  // Indicates if it is set to `false`, comparison of column names and schema field
+  // names is not case sensitive.
+  private val caseSensitive = SQLConf.get.caseSensitiveAnalysis
+
+  // Indicates if it is `true`, column names are ignored otherwise the CSV column
+  // names are checked for conformance to the schema. In the case if
+  // the column name don't conform to the schema, an exception is thrown.
+  private val enforceSchema = options.enforceSchema
+
+  /**
+   * Checks that column names in a CSV header and field names in the schema are the same
+   * by taking into account case sensitivity.
+   *
+   * @param columnNames names of CSV columns that must be checked against to the schema.
+   */
+  private def checkHeaderColumnNames(columnNames: Array[String]): Unit = {
+    if (columnNames != null) {
+      val fieldNames = schema.map(_.name).toIndexedSeq
+      val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+      var errorMessage: Option[String] = None
+
+      if (headerLen == schemaSize) {
+        var i = 0
+        while (errorMessage.isEmpty && i < headerLen) {
+          var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+          if (!caseSensitive) {
+            // scalastyle:off caselocale
+            nameInSchema = nameInSchema.toLowerCase
+            nameInHeader = nameInHeader.toLowerCase
+            // scalastyle:on caselocale
+          }
+          if (nameInHeader != nameInSchema) {
+            errorMessage = Some(
+              s"""|CSV header does not conform to the schema.
+                  | Header: ${columnNames.mkString(", ")}
+                  | Schema: ${fieldNames.mkString(", ")}
+                  |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+                  |$source""".stripMargin)
+          }
+          i += 1
+        }
+      } else {
+        errorMessage = Some(
+          s"""|Number of column in CSV header is not equal to number of fields in the schema:
+              | Header length: $headerLen, schema size: $schemaSize
+              |$source""".stripMargin)
+      }
+
+      errorMessage.foreach { msg =>
+        if (enforceSchema) {
+          logWarning(msg)
+        } else {
+          throw new IllegalArgumentException(msg)
+        }
+      }
+    }
+  }
+
+  // This is currently only used to parse CSV from Dataset[String].
+  def checkHeaderColumnNames(line: String): Unit = {
+    if (options.headerFlag) {
+      val parser = new CsvParser(options.asParserSettings)
+      checkHeaderColumnNames(parser.parseLine(line))
+    }
+  }
+
+  // This is currently only used to parse CSV with multiLine mode.
+  private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = {
+    assert(options.multiLine, "This method should be executed with multiLine.")
+    if (options.headerFlag) {
+      val firstRecord = tokenizer.parseNext()
+      checkHeaderColumnNames(firstRecord)
+    }
+  }
+
+  // This is currently only used to parse CSV with non-multiLine mode.
+  private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = {
+    assert(!options.multiLine, "This method should not be executed with multiline.")
+    // Checking that column names in the header are matched to field names of the schema.
+    // The header will be removed from lines.
+    // Note: if there are only comments in the first block, the header would probably
+    // be not extracted.
+    if (options.headerFlag && isStartOfFile) {
+      CSVUtils.extractHeader(lines, options).foreach { header =>
+        checkHeaderColumnNames(tokenizer.parseLine(header))
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 7ce65fa..b912f8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources.csv
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
 
 object CSVUtils {
   /**
@@ -90,6 +89,49 @@ object CSVUtils {
       None
     }
   }
+
+  /**
+   * Generates a header from the given row which is null-safe and duplicate-safe.
+   */
+  def makeSafeHeader(
+      row: Array[String],
+      caseSensitive: Boolean,
+      options: CSVOptions): Array[String] = {
+    if (options.headerFlag) {
+      val duplicates = {
+        val headerNames = row.filter(_ != null)
+          // scalastyle:off caselocale
+          .map(name => if (caseSensitive) name else name.toLowerCase)
+        // scalastyle:on caselocale
+        headerNames.diff(headerNames.distinct).distinct
+      }
+
+      row.zipWithIndex.map { case (value, index) =>
+        if (value == null || value.isEmpty || value == options.nullValue) {
+          // When there are empty strings or the values set in `nullValue`, put the
+          // index as the suffix.
+          s"_c$index"
+          // scalastyle:off caselocale
+        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+          // scalastyle:on caselocale
+          // When there are case-insensitive duplicates, put the index as the suffix.
+          s"$value$index"
+        } else if (duplicates.contains(value)) {
+          // When there are duplicates, put the index as the suffix.
+          s"$value$index"
+        } else {
+          value
+        }
+      }
+    } else {
+      row.zipWithIndex.map { case (_, index) =>
+        // Uses default column names, "_c#" where # is its position of fields
+        // when header option is disabled.
+        s"_c$index"
+      }
+    }
+  }
+
   /**
    * Helper method that converts string representation of a character to actual character.
    * It handles some Java escaped strings and throws exception if given string is longer than one

http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 9088d43..fbd19c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -273,7 +273,10 @@ private[csv] object UnivocityParser {
       inputStream: InputStream,
       shouldDropHeader: Boolean,
       tokenizer: CsvParser): Iterator[Array[String]] = {
-    convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens)
+    val handleHeader: () => Unit =
+      () => if (shouldDropHeader) tokenizer.parseNext
+
+    convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens)
   }
 
   /**
@@ -281,10 +284,9 @@ private[csv] object UnivocityParser {
    */
   def parseStream(
       inputStream: InputStream,
-      shouldDropHeader: Boolean,
       parser: UnivocityParser,
-      schema: StructType,
-      checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
+      headerChecker: CSVHeaderChecker,
+      schema: StructType): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
     val safeParser = new FailureSafeParser[Array[String]](
       input => Seq(parser.convert(input)),
@@ -292,25 +294,26 @@ private[csv] object UnivocityParser {
       schema,
       parser.options.columnNameOfCorruptRecord,
       parser.options.multiLine)
-    convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
+
+    val handleHeader: () => Unit =
+      () => headerChecker.checkHeaderColumnNames(tokenizer)
+
+    convertStream(inputStream, tokenizer, handleHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
   }
 
   private def convertStream[T](
       inputStream: InputStream,
-      shouldDropHeader: Boolean,
       tokenizer: CsvParser,
-      checkHeader: Array[String] => Unit = _ => ())(
+      handleHeader: () => Unit)(
       convert: Array[String] => T) = new Iterator[T] {
     tokenizer.beginParsing(inputStream)
-    private var nextRecord = {
-      if (shouldDropHeader) {
-        val firstRecord = tokenizer.parseNext()
-        checkHeader(firstRecord)
-      }
-      tokenizer.parseNext()
-    }
+
+    // We can handle header here since here the stream is open.
+    handleHeader()
+
+    private var nextRecord = tokenizer.parseNext()
 
     override def hasNext: Boolean = nextRecord != null
 
@@ -330,7 +333,10 @@ private[csv] object UnivocityParser {
   def parseIterator(
       lines: Iterator[String],
       parser: UnivocityParser,
+      headerChecker: CSVHeaderChecker,
       schema: StructType): Iterator[InternalRow] = {
+    headerChecker.checkHeaderColumnNames(lines, parser.tokenizer)
+
     val options = parser.options
 
     val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)

