Repository: spark
Updated Branches:
  refs/heads/master 14d7c1c3e -> a8a1ac01c


[SPARK-24959][SQL] Speed up count() for JSON and CSV

## What changes were proposed in this pull request?

In this PR, I propose to skip invoking the CSV/JSON parser for each line when the
required schema is empty. The added benchmarks for `count()` show a performance
improvement of up to **3.5 times**.
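
For example, a count-only query like the one below hits the new fast path: no columns
of the file schema are actually required, so per-record parsing can be skipped
entirely. A minimal sketch, assuming a `SparkSession` named `spark`; the schema and
path are hypothetical:

```scala
import org.apache.spark.sql.types._

// Hypothetical schema and path, for illustration only.
val schema = new StructType().add("col0", IntegerType).add("col1", IntegerType)
val cnt = spark.read.schema(schema).csv("/tmp/data.csv").count()
```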

Before:

```
Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                               7676 / 7715          1.3         767.6
CSV count()                                3309 / 3363          3.0         330.9
```

After:

```
Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------------
JSON count()                               2104 / 2156          4.8         210.4
CSV count()                                2332 / 2386          4.3         233.2
```
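
The core of the change lives in `FailureSafeParser` (see the diff below): when the
input is not multiLine, the parse mode is permissive, and the required schema is
empty, each record is mapped to an empty row without ever invoking the underlying
parser. A simplified, self-contained sketch of that fast path, using stand-in types
rather than the actual Spark internals:

```scala
// Simplified sketch of the FailureSafeParser fast path added in this PR.
// `rawParser` and the row type are stand-ins for Spark's internals.
class SketchFailureSafeParser[IN](
    rawParser: IN => Seq[Seq[Any]],
    isMultiLine: Boolean,
    isPermissive: Boolean,
    schemaIsEmpty: Boolean) {

  // Parsing can be bypassed only when records map 1:1 to inputs (not
  // multiLine), malformed records would be tolerated anyway (permissive
  // mode), and no column values are needed (empty required schema).
  private val skipParsing = !isMultiLine && isPermissive && schemaIsEmpty

  def parse(input: IN): Iterator[Seq[Any]] =
    if (skipParsing) Iterator.single(Seq.empty) // one empty row per record
    else rawParser(input).toIterator
}
```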

## How was this patch tested?

It was tested by `CSVSuite` and `JSONSuite`, as well as by the added benchmarks.
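
Since the new benchmark suites are plain `main`-based objects (see the diff below),
they can also be launched directly from a test harness. A hypothetical invocation,
assuming Spark's sql test classes are on the classpath:

```scala
// Hypothetical direct invocation of the benchmark entry points.
org.apache.spark.sql.execution.datasources.csv.CSVBenchmarks.main(Array.empty[String])
org.apache.spark.sql.execution.datasources.json.JSONBenchmarks.main(Array.empty[String])
```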

Author: Maxim Gekk <maxim.g...@databricks.com>
Author: Maxim Gekk <max.g...@gmail.com>

Closes #21909 from MaxGekk/empty-schema-optimization.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8a1ac01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8a1ac01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8a1ac01

Branch: refs/heads/master
Commit: a8a1ac01c4732f8a738b973c8486514cd88bf99b
Parents: 14d7c1c
Author: Maxim Gekk <maxim.g...@databricks.com>
Authored: Sat Aug 18 10:34:49 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Sat Aug 18 10:34:49 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/json/JacksonParser.scala |  3 +-
 .../org/apache/spark/sql/DataFrameReader.scala  |  6 ++-
 .../datasources/FailureSafeParser.scala         | 12 +++++-
 .../datasources/csv/UnivocityParser.scala       | 16 +++----
 .../datasources/json/JsonDataSource.scala       |  6 ++-
 .../datasources/csv/CSVBenchmarks.scala         | 39 +++++++++++++++++
 .../execution/datasources/csv/CSVSuite.scala    | 26 +++++++++++
 .../datasources/json/JsonBenchmarks.scala       | 45 +++++++++++++++++++-
 .../execution/datasources/json/JsonSuite.scala  | 27 +++++++++++-
 9 files changed, 159 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 6feea50..984979a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.json
 
 import java.io.{ByteArrayOutputStream, CharConversionException}
+import java.nio.charset.MalformedInputException
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -402,7 +403,7 @@ class JacksonParser(
         }
       }
     } catch {
-      case e @ (_: RuntimeException | _: JsonProcessingException) =>
+      case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
         // JSON parser currently doesn't support partial results for corrupted records.
         // For such records, all fields other than the field configured by
         // `columnNameOfCorruptRecord` are set to `null`.

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9bd1134..1b3a9fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => rawParser.parse(input, createParser, UTF8String.fromString),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
@@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         input => Seq(rawParser.parse(input)),
         parsedOptions.parseMode,
         schema,
-        parsedOptions.columnNameOfCorruptRecord)
+        parsedOptions.columnNameOfCorruptRecord,
+        parsedOptions.multiLine)
       iter.flatMap(parser.parse)
     }
     sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
index 43591a9..90e8166 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -28,7 +29,8 @@ class FailureSafeParser[IN](
     rawParser: IN => Seq[InternalRow],
     mode: ParseMode,
     schema: StructType,
-    columnNameOfCorruptRecord: String) {
+    columnNameOfCorruptRecord: String,
+    isMultiLine: Boolean) {
 
   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
@@ -56,9 +58,15 @@ class FailureSafeParser[IN](
     }
   }
 
+  private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty
+
   def parse(input: IN): Iterator[InternalRow] = {
     try {
-      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+     if (skipParsing) {
+       Iterator.single(InternalRow.empty)
+     } else {
+       rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+     }
     } catch {
       case e: BadRecordException => mode match {
         case PermissiveMode =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 79143cc..e15af42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -203,19 +203,11 @@ class UnivocityParser(
     }
   }
 
-  private val doParse = if (requiredSchema.nonEmpty) {
-    (input: String) => convert(tokenizer.parseLine(input))
-  } else {
-    // If `columnPruning` enabled and partition attributes scanned only,
-    // `schema` gets empty.
-    (_: String) => InternalRow.empty
-  }
-
   /**
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): InternalRow = doParse(input)
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))
 
   private val getToken = if (options.columnPruning) {
     (tokens: Array[String], index: Int) => tokens(index)
@@ -293,7 +285,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
@@ -341,7 +334,8 @@ private[csv] object UnivocityParser {
       input => Seq(parser.parse(input)),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     filteredLines.flatMap(safeParser.parse)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index d6c5888..76f5837 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource {
       input => parser.parse(input, textParser, textToUTF8String),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
     linesReader.flatMap(safeParser.parse)
   }
 
@@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
       input => parser.parse[InputStream](input, streamParser, partitionedFileString),
       parser.options.parseMode,
       schema,
-      parser.options.columnNameOfCorruptRecord)
+      parser.options.columnNameOfCorruptRecord,
+      parser.options.multiLine)
 
     safeParser.parse(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))))

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
index 1a3dacb..24f5f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala
@@ -119,8 +119,47 @@ object CSVBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .csv(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).csv(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()              12598 / 12740          0.8        1259.8       1.0X
+      Select 1 column + count()                  7960 / 8175          1.3         796.0       1.6X
+      count()                                    2332 / 2386          4.3         233.2       5.4X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3)
     multiColumnsBenchmark(rowsNum = 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 456b453..14840e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       }
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", IntegerType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).option("header", false).csv(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = "1"
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("0xAC", validRec),
+        Seq(validRec, "0.314"),
+        Seq("\\\\\\", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedCSV(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedCSV(0, Seq(""))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
index 85cf054..a2b747e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.File
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
 import org.apache.spark.util.{Benchmark, Utils}
 
 /**
@@ -171,9 +172,49 @@ object JSONBenchmarks {
     }
   }
 
+  def countBenchmark(rowsNum: Int): Unit = {
+    val colsNum = 10
+    val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum)
+
+    withTempPath { path =>
+      val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType))
+      val schema = StructType(fields)
+      val columnNames = schema.fieldNames
+
+      spark.range(rowsNum)
+        .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*)
+        .write
+        .json(path.getAbsolutePath)
+
+      val ds = spark.read.schema(schema).json(path.getAbsolutePath)
+
+      benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ =>
+        ds.select("*").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"Select 1 column + count()", 3) { _ =>
+        ds.select($"col1").filter((_: Row) => true).count()
+      }
+      benchmark.addCase(s"count()", 3) { _ =>
+        ds.count()
+      }
+
+      /*
+      Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+
+      Count a dataset with 10 columns:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+      ---------------------------------------------------------------------------------------------
+      Select 10 columns + count()               9961 / 10006          1.0         996.1       1.0X
+      Select 1 column + count()                  8355 / 8470          1.2         835.5       1.2X
+      count()                                    2104 / 2156          4.8         210.4       4.7X
+      */
+      benchmark.run()
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     schemaInferring(100 * 1000 * 1000)
     perlineParsing(100 * 1000 * 1000)
     perlineParsingOfWideColumn(10 * 1000 * 1000)
+    countBenchmark(10 * 1000 * 1000)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a8a1ac01/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 655f40a..3e4cc8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2223,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
   }
 
-
   test("SPARK-23723: specified encoding is not matched to actual encoding") {
     val fileName = "test-data/utf16LE.json"
     val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
@@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
     }
   }
+
+  test("count() for malformed input") {
+    def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = {
+      val schema = new StructType().add("a", StringType)
+      val strings = spark.createDataset(input)
+      val df = spark.read.schema(schema).json(strings)
+
+      assert(df.count() == expected)
+    }
+    def checkCount(expected: Long): Unit = {
+      val validRec = """{"a":"b"}"""
+      val inputs = Seq(
+        Seq("{-}", validRec),
+        Seq(validRec, "?"),
+        Seq("}", validRec),
+        Seq(validRec, """{"a": [1, 2, 3]}"""),
+        Seq("""{"a": {"a": "b"}}""", validRec)
+      )
+      inputs.foreach { input =>
+        countForMalformedJSON(expected, input)
+      }
+    }
+
+    checkCount(2)
+    countForMalformedJSON(0, Seq(""))
+  }
 }

