Repository: spark Updated Branches: refs/heads/master bcf7121ed -> d463533de
[SPARK-24676][SQL] Project required data from CSV parsed data when column pruning disabled

## What changes were proposed in this pull request?
This PR modifies the code to project the required data from the CSV-parsed data when column pruning is disabled.
In the current master, the exception below is thrown if `spark.sql.csv.parser.columnPruning.enabled` is false. This is because the required schema and the schema of the CSV-parsed rows differ from each other:
```
./bin/spark-shell --conf spark.sql.csv.parser.columnPruning.enabled=false
scala> val dir = "/tmp/spark-csv/csv"
scala> spark.range(10).selectExpr("id % 2 AS p", "id").write.mode("overwrite").partitionBy("p").csv(dir)
scala> spark.read.csv(dir).selectExpr("sum(p)").collect()
18/06/25 13:48:46 ERROR Executor: Exception in task 2.0 in stage 2.0 (TID 7)
java.lang.ClassCastException: org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer
        at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:101)
        at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getInt(rows.scala:41)
        ...
```

## How was this patch tested?
Added tests in `CSVSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #21657 from maropu/SPARK-24676.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d463533d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d463533d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d463533d Branch: refs/heads/master Commit: d463533ded89a05e9f77e590fd3de2ffa212d68b Parents: bcf7121 Author: Takeshi Yamamuro <yamam...@apache.org> Authored: Sun Jul 15 20:22:09 2018 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Sun Jul 15 20:22:09 2018 -0700 ---------------------------------------------------------------------- .../datasources/csv/UnivocityParser.scala | 54 +++++++++++++++----- .../execution/datasources/csv/CSVSuite.scala | 29 +++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d463533d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index aa545e1..79143cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,29 +33,49 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. 
This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), - "requiredSchema should be the subset of schema.") + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + val tokenizer = { val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { parserSetting.selectIndexes(tokenIndexArr: _*) } new CsvParser(parserSetting) } - private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. 
private def getCurrentInput: UTF8String = { @@ -82,7 +102,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -183,7 +203,7 @@ class UnivocityParser( } } - private val doParse = if (schema.nonEmpty) { + private val doParse = if (requiredSchema.nonEmpty) { (input: String) => convert(tokenizer.parseLine(input)) } else { // If `columnPruning` enabled and partition attributes scanned only, @@ -197,15 +217,21 @@ class UnivocityParser( */ def parse(input: String): InternalRow = doParse(input) + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } + private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. 
- val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -222,9 +248,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row http://git-wip-us.apache.org/repos/asf/spark/blob/d463533d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 84b91f6..ae8110f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1579,4 +1579,33 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", 
true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org