Repository: spark
Updated Branches:
  refs/heads/master 14d7c1c3e -> a8a1ac01c
[SPARK-24959][SQL] Speed up count() for JSON and CSV

## What changes were proposed in this pull request?

In this PR, I propose to skip invoking the CSV/JSON parser for each line when the required schema is empty. The added benchmarks for `count()` show a performance improvement of up to **3.5x**.

Before:
```
Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                                7676 / 7715          1.3         767.6
CSV count()                                 3309 / 3363          3.0         330.9
```

After:
```
Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                                2104 / 2156          4.8         210.4
CSV count()                                 2332 / 2386          4.3         233.2
```

## How was this patch tested?

It was tested by `CSVSuite` and `JSONSuite`, as well as by the added benchmarks.

Author: Maxim Gekk <maxim.g...@databricks.com>
Author: Maxim Gekk <max.g...@gmail.com>

Closes #21909 from MaxGekk/empty-schema-optimization.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8a1ac01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8a1ac01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8a1ac01

Branch: refs/heads/master
Commit: a8a1ac01c4732f8a738b973c8486514cd88bf99b
Parents: 14d7c1c
Author: Maxim Gekk <maxim.g...@databricks.com>
Authored: Sat Aug 18 10:34:49 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Sat Aug 18 10:34:49 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/json/JacksonParser.scala |  3 +-
 .../org/apache/spark/sql/DataFrameReader.scala  |  6 ++-
 .../datasources/FailureSafeParser.scala         | 12 +++++-
 .../datasources/csv/UnivocityParser.scala       | 16 +++----
 .../datasources/json/JsonDataSource.scala       |  6 ++-
 .../datasources/csv/CSVBenchmarks.scala         | 39 +++++++++++++++++
 .../execution/datasources/csv/CSVSuite.scala    | 26 +++++++++++
 .../datasources/json/JsonBenchmarks.scala       | 45 +++++++++++++++++++-
 .../execution/datasources/json/JsonSuite.scala  | 27 +++++++++++-
 9 files changed, 159 insertions(+), 21 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 6feea50..984979a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json
 
 import java.io.{ByteArrayOutputStream, CharConversionException}
+import java.nio.charset.MalformedInputException
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -402,7 +403,7 @@ class JacksonParser(
       }
     }
   } catch {
-    case e @ (_: RuntimeException | _: JsonProcessingException) =>
+    case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
       // JSON parser currently doesn't support partial results for corrupted records.
       // For such records, all fields other than the field configured by
       // `columnNameOfCorruptRecord` are set to `null`.
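A note on the broadened catch clause above: `MalformedInputException` is what `java.nio` charset decoders raise when a record's bytes are invalid for the configured encoding, so catching it alongside `JsonProcessingException` turns wrong-encoding input into ordinary corrupt records instead of task failures. A minimal, self-contained sketch of how that exception arises (the demo object and the byte value are illustrative, not part of the patch):

```scala
import java.nio.ByteBuffer
import java.nio.charset.{Charset, CodingErrorAction, MalformedInputException}

object MalformedInputDemo {
  def main(args: Array[String]): Unit = {
    // REPORT makes the decoder throw instead of silently replacing bad bytes.
    val decoder = Charset.forName("UTF-8").newDecoder()
      .onMalformedInput(CodingErrorAction.REPORT)
    try {
      // 0xFF is never a valid byte in UTF-8, so decoding fails.
      decoder.decode(ByteBuffer.wrap(Array(0xFF.toByte)))
    } catch {
      case e: MalformedInputException =>
        // JacksonParser now treats this like any other bad record.
        println(s"caught: $e")
    }
  }
}
```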
http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9bd1134..1b3a9fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
@@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => Seq(rawParser.parse(input)),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
index 43591a9..90e8166 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
@@ -28,7 +29,8 @@ class FailureSafeParser[IN](
     rawParser: IN => Seq[InternalRow],
     mode: ParseMode,
     schema: StructType,
-    columnNameOfCorruptRecord: String) {
+    columnNameOfCorruptRecord: String,
+    isMultiLine: Boolean) {
 
   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
@@ -56,9 +58,15 @@ class FailureSafeParser[IN](
     }
   }
 
+  private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty
+
   def parse(input: IN): Iterator[InternalRow] = {
     try {
-      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      if (skipParsing) {
+        Iterator.single(InternalRow.empty)
+      } else {
+        rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+      }
     } catch {
       case e: BadRecordException => mode match {
         case PermissiveMode =>
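The `skipParsing` flag added above is the heart of the change: when the input is line-delimited (`!isMultiLine`), the mode is `PermissiveMode` (so even a malformed record yields a row), and the required schema is empty (as it is for `count()` after column pruning), every input record maps to exactly one empty row and the underlying parser never has to run. A stripped-down sketch of the same fast path outside Spark's internals (all names here are mine, not Spark's):

```scala
// Hypothetical, simplified model of the fast path in FailureSafeParser;
// Row below is a stand-in for Spark's InternalRow.
object SkipParsingSketch {
  type Row = Seq[Any]
  val EmptyRow: Row = Nil

  def makeParser(
      rawParser: String => Seq[Row], // the expensive CSV/JSON parse
      isMultiLine: Boolean,
      isPermissive: Boolean,
      requiredSchemaIsEmpty: Boolean): String => Iterator[Row] = {
    // All three conditions must hold; otherwise malformed records or
    // record boundaries still require a real parse.
    val skipParsing = !isMultiLine && isPermissive && requiredSchemaIsEmpty
    if (skipParsing) {
      _ => Iterator.single(EmptyRow) // one empty row per record, no parsing
    } else {
      input => rawParser(input).iterator
    }
  }

  def main(args: Array[String]): Unit = {
    val parse = makeParser(
      rawParser = _ => sys.error("should not be called"),
      isMultiLine = false, isPermissive = true, requiredSchemaIsEmpty = true)
    // count() only needs the number of rows, so the raw parser is never hit.
    println(Seq("a,b", "malformed", "1,2").iterator.flatMap(parse).size) // 3
  }
}
```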
http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 79143cc..e15af42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -203,19 +203,11 @@ class UnivocityParser(
     }
   }
 
-  private val doParse = if (requiredSchema.nonEmpty) {
-    (input: String) => convert(tokenizer.parseLine(input))
-  } else {
-    // If `columnPruning` enabled and partition attributes scanned only,
-    // `schema` gets empty.
-    (_: String) => InternalRow.empty
-  }
-
   /**
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): InternalRow = doParse(input)
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))
 
   private val getToken = if (options.columnPruning) {
     (tokens: Array[String], index: Int) => tokens(index)
@@ -293,7 +285,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
@@ -341,7 +334,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     filteredLines.flatMap(safeParser.parse)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index d6c5888..76f5837 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource {
       input => parser.parse(input, textParser, textToUTF8String),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     linesReader.flatMap(safeParser.parse)
   }
@@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
       input => parser.parse[InputStream](input, streamParser, partitionedFileString),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     safeParser.parse(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))
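Both JSON read paths now pass their `multiLine` setting through to `FailureSafeParser`, and the multi-line source hands in `true`, which keeps the shortcut off: with `multiLine` an entire file can be a single record, so even counting records requires a real parse. A hedged usage sketch of the resulting behavior (file path and session setup are placeholders, not from the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object CountFastPathExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count-demo").getOrCreate()
    val schema = new StructType().add("id", LongType)

    // Eligible for the shortcut: line-delimited JSON, default PERMISSIVE mode,
    // and count() prunes all columns, so the required schema is empty.
    val fast = spark.read.schema(schema).json("/tmp/events.json").count()

    // Not eligible: multiLine needs a full parse to find record boundaries.
    val slow = spark.read.schema(schema).option("multiLine", true)
      .json("/tmp/events.json").count()

    println(s"$fast == $slow") // same answer, different amount of work
    spark.stop()
  }
}
```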
http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
index 1a3dacb..24f5f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -119,8 +119,47 @@ object CSVBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .csv(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).csv(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()          12598 / 12740          0.8        1259.8       1.0X
+      Select 1 column + count()             7960 / 8175           1.3         796.0       1.6X
+      count()                               2332 / 2386           4.3         233.2       5.4X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
     multiColumnsBenchmark(rowsNum = 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
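Why the benchmark wraps two of its cases in `.filter((_: Row) => true)`: a typed filter is opaque to the optimizer, so the selected columns must actually be parsed and materialized, whereas a bare `count()` lets Spark prune every column down to an empty required schema, which is the new fast path. A self-contained illustration of that contrast (object name, paths, and data are mine, not the benchmark's):

```scala
import org.apache.spark.sql.{Row, SparkSession}

object PruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pruning-demo").getOrCreate()

    // Write a small two-column CSV dataset to a temp directory.
    val path = java.nio.file.Files.createTempDirectory("csv-demo").resolve("data").toString
    spark.range(0, 1000).selectExpr("id as col0", "id as col1").write.csv(path)

    val ds = spark.read.schema("col0 LONG, col1 LONG").csv(path)
    println(ds.select("*").filter((_: Row) => true).count()) // forces a full parse
    println(ds.count())                                      // parser can be skipped
    spark.stop()
  }
}
```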
http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 456b453..14840e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       }
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", IntegerType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).option("header", false).csv(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = "1"
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("0xAC", validRec),
+        Seq(validRec, "0.314"),
+        Seq("\\\\\\", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedCSV(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedCSV(0, Seq(""))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
index 85cf054..a2b747e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.File
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
 import org.apache.spark.util.{Benchmark, Utils}
 
 /**
@@ -171,9 +172,49 @@ object JSONBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+      val columnNames = schema.fieldNames
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .json(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).json(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()           9961 / 10006          1.0         996.1       1.0X
+      Select 1 column + count()             8355 / 8470           1.2         835.5       1.2X
+      count()                               2104 / 2156           4.8         210.4       4.7X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     schemaInferring(100 * 1000 * 1000)
     perlineParsing(100 * 1000 * 1000)
     perlineParsingOfWideColumn(10 * 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }
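The shortcut's `PermissiveMode` requirement shows up directly in counting semantics: in `DROPMALFORMED` mode a bad record must not be counted, so each record has to be parsed to learn whether it is bad. A self-contained sketch of the difference (object name and session setup are mine; the input strings mirror the tests below):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object ModeGateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("mode-demo").getOrCreate()
    import spark.implicits._

    val schema = new StructType().add("a", StringType)
    val input = spark.createDataset(Seq("""{"a":"b"}""", "?"))

    // PERMISSIVE (default): the malformed record becomes a null row, count is 2.
    println(spark.read.schema(schema).json(input).count())

    // DROPMALFORMED: the bad line is discarded, count is 1.
    println(spark.read.schema(schema).option("mode", "DROPMALFORMED").json(input).count())
    spark.stop()
  }
}
```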
http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 655f40a..3e4cc8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2223,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
   }
 
-
   test("SPARK-23723: specified encoding is not matched to actual encoding") {
     val fileName = "test-data/utf16LE.json"
     val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
@@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", StringType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).json(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = """{"a":"b"}"""
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("}", validRec),
+        Seq(validRec, """{"a": [1, 2, 3]}"""),
+        Seq("""{"a": {"a": "b"}}""", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedJSON(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedJSON(0, Seq(""))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org