Repository: spark Updated Branches: refs/heads/master ce233f18e -> 7e5359be5
[SPARK-19610][SQL] Support parsing multiline CSV files ## What changes were proposed in this pull request? This PR adds support for multiline records in CSV, mirroring the multiline support in the JSON datasource (which, for JSON, is per file). To do so, it introduces a `wholeFile` option that makes the format non-splittable and reads each file as a whole. Since the Univocity parser can produce each row from a stream, it should be capable of parsing very large documents as long as each internal row fits in memory. ## How was this patch tested? Unit tests in `CSVSuite` and `tests.py`. Manual tests with a single 9GB CSV file in the local file system, for example: ```scala spark.read.option("wholeFile", true).option("inferSchema", true).csv("tmp.csv").count() ``` Author: hyukjinkwon <gurwls...@gmail.com> Closes #16976 from HyukjinKwon/SPARK-19610. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e5359be Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e5359be Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e5359be Branch: refs/heads/master Commit: 7e5359be5ca038fdb579712b18e7f226d705c276 Parents: ce233f1 Author: hyukjinkwon <gurwls...@gmail.com> Authored: Tue Feb 28 13:34:33 2017 -0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Feb 28 13:34:33 2017 -0800 ---------------------------------------------------------------------- python/pyspark/sql/readwriter.py | 6 +- python/pyspark/sql/streaming.py | 6 +- python/pyspark/sql/tests.py | 9 +- python/test_support/sql/ages_newlines.csv | 6 + .../org/apache/spark/sql/DataFrameReader.scala | 1 + .../execution/datasources/CodecStreams.scala | 12 + .../datasources/csv/CSVDataSource.scala | 239 +++++++++++++++++++ .../datasources/csv/CSVFileFormat.scala | 77 ++---- .../datasources/csv/CSVInferSchema.scala | 59 +---- .../execution/datasources/csv/CSVOptions.scala | 2 + .../datasources/csv/UnivocityParser.scala | 94 +++++++- .../datasources/json/JsonDataSource.scala | 18 +- .../spark/sql/streaming/DataStreamReader.scala | 1 + .../execution/datasources/csv/CSVSuite.scala | 192 ++++++++++----- 14 files changed, 525 insertions(+), 197 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/python/pyspark/sql/readwriter.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b5e5b18..ec47618 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -308,7 +308,7 @@ class DataFrameReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -385,6 +385,8 @@ class DataFrameReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. If None is + set, it uses the default value, ``false``.
>>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -398,7 +400,7 @@ class DataFrameReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/python/pyspark/sql/streaming.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bd19fd4..7587875 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -562,7 +562,7 @@ class DataStreamReader(OptionUtils): ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -637,6 +637,8 @@ class DataStreamReader(OptionUtils): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. 
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -652,7 +654,7 @@ class DataStreamReader(OptionUtils): dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fd083e4..e943f8d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -437,12 +437,19 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/python/test_support/sql/ages_newlines.csv ---------------------------------------------------------------------- diff --git a/python/test_support/sql/ages_newlines.csv b/python/test_support/sql/ages_newlines.csv new file mode 100644 index 0000000..d19f673 --- /dev/null +++ b/python/test_support/sql/ages_newlines.csv @@ -0,0 +1,6 @@ +Joe,20,"Hi, +I am Jeo" +Tom,30,"My name is Tom" +Hyukjin,25,"I am Hyukjin + +I love Spark!" http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 59baf6e..63be1e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * <li>`columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. 
This overrides `spark.sql.columnNameOfCorruptRecord`.</li> + * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li> * </ul> * @since 2.0.0 */ http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 0762d1b..54549f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils +import org.apache.spark.TaskContext + object CodecStreams { private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { val compressionCodecs = new CompressionCodecFactory(config) @@ -42,6 +44,16 @@ object CodecStreams { .getOrElse(inputStream) } + /** + * Creates an input stream from the string path and add a closure for the input stream to be + * closed on task completion. + */ + def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { + val inputStream = createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala new file mode 100644 index 0000000..73e6abc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.InputStream +import java.nio.charset.{Charset, StandardCharsets} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.types.StructType + +/** + * Common functions for parsing CSV files + */ +abstract class CSVDataSource extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into [[InternalRow]] instances. + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] + + /** + * Infers the schema from `inputPaths` files. + */ + def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + protected def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. 
+ s"_c$index" + } + } + } +} + +object CSVDataSource { + def apply(options: CSVOptions): CSVDataSource = { + if (options.wholeFile) { + WholeFileCSVDataSource + } else { + TextInputCSVDataSource + } + } +} + +object TextInputCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = true + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + val lines = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + } + } + + val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) + } else { + val charset = options.charset + val rdd = sparkSession.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(",")) + .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + sparkSession.createDataset(rdd)(Encoders.STRING) + } + } +} + +object WholeFileCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = false + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + UnivocityParser.parseStream( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + parsedOptions.headerFlag, + parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) + val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + false, + new CsvParser(parsedOptions.asParserSettings)) + }.take(1).headOption + + if (maybeFirstRow.isDefined) { + 
val firstRow = maybeFirstRow.get + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } else { + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) + } + } + + private def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): RDD[PortableDataStream] = { + val paths = inputPaths.map(_.getPath) + val name = paths.mkString(",") + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + val conf = job.getConfiguration + + val rdd = new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + + // Only returns `PortableDataStream`s without paths. + rdd.setName(s"CSVFile: $name").values + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 59f2919..29c4145 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.{Charset, StandardCharsets} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -43,11 +37,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "csv" - override def toString: String = "CSV" - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvDataSource = CSVDataSource(parsedOptions) + csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( 
sparkSession: SparkSession, @@ -55,11 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { files: Seq[FileStatus]): Option[StructType] = { require(files.nonEmpty, "Cannot infer schema from an empty set of files") - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val paths = files.map(_.getPath.toString) - val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + + CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) } override def prepareWrite( @@ -115,49 +112,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val lines = { - val conf = broadcastedHadoopConf.value.value - val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) - } - } - - val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, parsedOptions) - } else { - lines - } - - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions) + val conf = broadcastedHadoopConf.value.value val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - filteredLines.flatMap(parser.parse) + CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) } } - private def createBaseDataset( - sparkSession: SparkSession, - options: CSVOptions, - inputPaths: Seq[String]): Dataset[String] = { - if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = inputPaths, - className = classOf[TextFileFormat].getName - ).resolveRelation(checkFilesExist = false)) - .select("value").as[String](Encoders.STRING) - } else { - val charset = options.charset - val rdd = sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(",")) - .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) - sparkSession.createDataset(rdd)(Encoders.STRING) - } - } + override def toString: String = "CSV" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] } private[csv] class CsvOutputWriter( http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3fa30fe..b64d71b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -21,11 +21,9 @@ import java.math.BigDecimal import scala.util.control.Exception._ -import 
com.univocity.parsers.csv.CsvParser - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ private[csv] object CSVInferSchema { @@ -37,24 +35,13 @@ private[csv] object CSVInferSchema { * 3. Replace any null types with string type */ def infer( - csv: Dataset[String], - caseSensitive: Boolean, + tokenRDD: RDD[Array[String]], + header: Array[String], options: CSVOptions): StructType = { - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first() - val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) - val header = makeSafeHeader(firstRow, caseSensitive, options) - val fields = if (options.inferSchemaFlag) { - val tokenRdd = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options) - val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options) - val parser = new CsvParser(options.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) header.zip(rootTypes).map { case (thisHeader, rootType) => val dType = rootType match { @@ -71,44 +58,6 @@ private[csv] object CSVInferSchema { StructType(fields) } - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - private def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. 
- s"_c$index" - } - } - } - private def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1caeec7..5050338 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -130,6 +130,8 @@ private[csv] class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val maxColumns = getInt("maxColumns", 20480) val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index eb47165..804031a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import java.io.InputStream import java.math.BigDecimal import java.text.NumberFormat import java.util.Locale @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] class UnivocityParser( schema: StructType, requiredSchema: StructType, - options: CSVOptions) extends Logging { + private val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -56,12 +57,15 @@ private[csv] class UnivocityParser( private val valueConverters = dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - private val parser = new CsvParser(options.asParserSettings) + private val tokenizer = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 private val row = new GenericInternalRow(requiredSchema.length) + // This gets the raw input that is parsed lately. + private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + // This parser loads an `indexArr._1`-th position value in input tokens, // then put the value in `row(indexArr._2)`. private val indexArr: Array[(Int, Int)] = { @@ -188,12 +192,13 @@ private[csv] class UnivocityParser( } /** - * Parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the + * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). 
*/ - def parse(input: String): Option[InternalRow] = { - convertWithParseMode(input) { tokens => + def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): Option[InternalRow] = { + convertWithParseMode(tokens) { tokens => var i: Int = 0 while (i < indexArr.length) { val (pos, rowIdx) = indexArr(i) @@ -211,8 +216,7 @@ private[csv] class UnivocityParser( } private def convertWithParseMode( - input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { - val tokens = parser.parseLine(input) + tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") @@ -251,7 +255,7 @@ private[csv] class UnivocityParser( } catch { case NonFatal(e) if options.permissive => val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input)) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { @@ -269,3 +273,75 @@ private[csv] class UnivocityParser( } } } + +private[csv] object UnivocityParser { + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of tokens. + */ + def tokenizeStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser): Iterator[Array[String]] = { + convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens) + } + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of rows. + */ + def parseStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val tokenizer = parser.tokenizer + convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + parser.convert(tokens) + }.flatten + } + + private def convertStream[T]( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer.beginParsing(inputStream) + private var nextRecord = { + if (shouldDropHeader) { + tokenizer.parseNext() + } + tokenizer.parseNext() + } + + override def hasNext: Boolean = nextRecord != null + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + val curRecord = convert(nextRecord) + nextRecord = tokenizer.parseNext() + curRecord + } + } + + /** + * Parses an iterator that contains CSV strings and turns it into an iterator of rows. + */ + def parseIterator( + lines: Iterator[String], + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val options = parser.options + + val linesWithoutHeader = if (shouldDropHeader) { + // Note that if there are only comments in the first block, the header would probably + // be not dropped. 
+ CSVUtils.dropHeaderLine(lines, options) + } else { + lines + } + + val filteredLines: Iterator[String] = + CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + filteredLines.flatMap(line => parser.parse(line)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3e984ef..18843bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.InputStream - import scala.reflect.ClassTag import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} @@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { } } - private def createInputStream(config: Configuration, path: String): InputStream = { - val inputStream = CodecStreams.createInputStream(config, new Path(path)) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) - inputStream - } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, - createInputStream(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) } override def readFile( @@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { - Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + Utils.tryWithResource { + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } parser.parse( - createInputStream(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), CreateJacksonParser.inputStream, partitionedFileString).toIterator } http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f78e73f..6a27528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -261,6 +261,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * <li>`columnNameOfCorruptRecord` (default is the value specified in * 
`spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li> + * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li> * </ul> * * @since 2.0.0 http://git-wip-us.apache.org/repos/asf/spark/blob/7e5359be/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 371d431..d94eb66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -24,11 +24,12 @@ import java.text.SimpleDateFormat import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat -import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -243,12 +244,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - val cars = spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { wholeFile => + val cars = spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } test("test for blank column names on read and select columns") { @@ -263,14 +267,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - val exception = intercept[SparkException]{ - spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "failfast")) - .load(testFile(carsFile)).collect() - } + Seq(false, true).foreach { wholeFile => + val exception = intercept[SparkException] { + spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } } test("test for tokens more than the fields in the schema") { @@ -961,56 +968,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - val schema = new StructType().add("a", IntegerType).add("b", TimestampType) - val df1 = spark - .read - .option("mode", "PERMISSIVE") - .schema(schema) - .csv(testFile(valueMalformedFile)) - checkAnswer(df1, - Row(null, null) :: - Row(1, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - // If `schema` has `columnNameOfCorruptRecord`, it should handle 
corrupt records - val columnNameOfCorruptRecord = "_unparsed" - val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) - val df2 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField1) - .csv(testFile(valueMalformedFile)) - checkAnswer(df2, - Row(null, null, "0,2013-111-11 12:13:14") :: - Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: - Nil) - - // We put a `columnNameOfCorruptRecord` field in the middle of a schema - val schemaWithCorrField2 = new StructType() - .add("a", IntegerType) - .add(columnNameOfCorruptRecord, StringType) - .add("b", TimestampType) - val df3 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField2) - .csv(testFile(valueMalformedFile)) - checkAnswer(df3, - Row(null, "0,2013-111-11 12:13:14", null) :: - Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - val errMsg = intercept[AnalysisException] { - spark + Seq(false, true).foreach { wholeFile => + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .option("wholeFile", wholeFile) + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) - .collect - }.getMessage - assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField2) + .csv(testFile(valueMalformedFile)) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } + + test("SPARK-19610: Parse normal multi-line CSV files") { + val primitiveFieldAndType = Seq( + """" + |string","integer + | + | + |","long + | + |","bigInteger",double,boolean,null""".stripMargin, + """"this is a + |simple + |string."," + | + |10"," + |21474836470","92233720368547758070"," + | + |1.7976931348623157E308",true,""".stripMargin) + + withTempPath { path => + 
primitiveFieldAndType.toDF("value").coalesce(1).write.text(path.getAbsolutePath) + + val df = spark.read + .option("header", true) + .option("wholeFile", true) + .csv(path.getAbsolutePath) + + // Check if headers have new lines in the names. + val actualFields = df.schema.fieldNames.toSeq + val expectedFields = + Seq("\nstring", "integer\n\n\n", "long\n\n", "bigInteger", "double", "boolean", "null") + assert(actualFields === expectedFields) + + // Check if the rows have new lines in the values. + val expected = Row( + "this is a\nsimple\nstring.", + "\n\n10", + "\n21474836470", + "92233720368547758070", + "\n\n1.7976931348623157E308", + "true", + null) + checkAnswer(df, expected) + } + } + + test("Empty file produces empty dataframe with empty schema - wholeFile option") { + withTempPath { path => + path.createNewFile() + + val df = spark.read.format("csv") + .option("header", true) + .option("wholeFile", true) + .load(path.getAbsolutePath) + + assert(df.schema === spark.emptyDataFrame.schema) + checkAnswer(df, spark.emptyDataFrame) + } } }
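For readers who want to try the new option end to end, here is a minimal, self-contained sketch against the public Scala API of the `wholeFile` flag added by this patch. The application name, input path, and the sample data referenced in the comments are illustrative only and not part of the commit; the behaviour shown mirrors the `ages_newlines.csv` fixture and the manual test described in the commit message above.

```scala
import org.apache.spark.sql.SparkSession

object WholeFileCsvExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("wholeFile CSV example") // illustrative app name
      .master("local[*]")
      .getOrCreate()

    // With wholeFile=true the CSV input is no longer split on line
    // boundaries; each file is handed to the Univocity parser as a single
    // stream, so a quoted field may contain embedded newlines and still be
    // parsed as one record.
    val df = spark.read
      .option("wholeFile", true)     // option introduced by SPARK-19610
      .option("header", false)
      .option("inferSchema", true)
      .csv("/tmp/ages_newlines.csv") // illustrative path

    // A quoted value that spans several physical lines comes back as one
    // column value with the newlines preserved, rather than as extra rows.
    df.show(truncate = false)

    spark.stop()
  }
}
```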