[SPARK-25393][SQL] Adding new function from_csv()

## What changes were proposed in this pull request?
The PR adds new function `from_csv()` similar to `from_json()` to parse columns with CSV strings. I added the following methods: ```Scala def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column ``` and this signature to call it from Python, R and Java: ```Scala def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column ``` ## How was this patch tested? Added new test suites `CsvExpressionsSuite`, `CsvFunctionsSuite` and sql tests. Closes #22379 from MaxGekk/from_csv. Lead-authored-by: Maxim Gekk <maxim.g...@databricks.com> Co-authored-by: Maxim Gekk <max.g...@gmail.com> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Co-authored-by: hyukjinkwon <gurwls...@apache.org> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9af9460 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9af9460 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9af9460 Branch: refs/heads/master Commit: e9af9460bc008106b670abac44a869721bfde42a Parents: 9d4dd79 Author: Maxim Gekk <maxim.g...@databricks.com> Authored: Wed Oct 17 09:32:05 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Wed Oct 17 09:32:05 2018 +0800 ---------------------------------------------------------------------- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 40 ++- R/pkg/R/generics.R | 4 + R/pkg/tests/fulltests/test_sparkSQL.R | 7 + python/pyspark/sql/functions.py | 37 +- sql/catalyst/pom.xml | 6 + .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../spark/sql/catalyst/csv/CSVExprUtils.scala | 82 +++++ .../sql/catalyst/csv/CSVHeaderChecker.scala | 131 +++++++ .../spark/sql/catalyst/csv/CSVOptions.scala | 217 ++++++++++++ .../sql/catalyst/csv/UnivocityParser.scala | 351 ++++++++++++++++++ .../sql/catalyst/expressions/ExprUtils.scala | 45 +++ .../catalyst/expressions/csvExpressions.scala | 120 +++++++ .../catalyst/expressions/jsonExpressions.scala | 21 +- .../sql/catalyst/util/FailureSafeParser.scala | 80 +++++ .../sql/catalyst/csv/CSVExprUtilsSuite.scala | 61 ++++ .../expressions/CsvExpressionsSuite.scala | 158 +++++++++ .../org/apache/spark/sql/DataFrameReader.scala | 5 +- .../datasources/FailureSafeParser.scala | 82 ----- .../datasources/csv/CSVDataSource.scala | 1 + .../datasources/csv/CSVFileFormat.scala | 1 + .../datasources/csv/CSVHeaderChecker.scala | 131 ------- .../datasources/csv/CSVInferSchema.scala | 1 + .../execution/datasources/csv/CSVOptions.scala | 217 ------------ .../execution/datasources/csv/CSVUtils.scala | 67 +--- .../datasources/csv/UnivocityGenerator.scala | 1 + .../datasources/csv/UnivocityParser.scala | 352 ------------------- .../datasources/json/JsonDataSource.scala | 1 + .../scala/org/apache/spark/sql/functions.scala | 32 ++ .../sql-tests/inputs/csv-functions.sql | 9 + .../sql-tests/results/csv-functions.sql.out | 69 ++++ .../apache/spark/sql/CsvFunctionsSuite.scala | 62 ++++ .../datasources/csv/CSVInferSchemaSuite.scala | 1 + .../datasources/csv/CSVUtilsSuite.scala | 61 ---- .../datasources/csv/UnivocityParserSuite.scala | 2 +- 35 files changed, 1531 insertions(+), 930 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/R/pkg/NAMESPACE ---------------------------------------------------------------------- diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 96ff389..c512284 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ 
-274,6 +274,7 @@ exportMethods("%<=>%", "floor", "format_number", "format_string", + "from_csv", "from_json", "from_unixtime", "from_utc_timestamp", http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/R/pkg/R/functions.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 6a8fef5..d2ca1d6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -188,6 +188,7 @@ NULL #' \item \code{to_json}: it is the column containing the struct, array of the structs, #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. +#' \item \code{from_csv}: it is the column containing the CSV string. #' } #' @param y Column to compute on. #' @param value A value to compute on. @@ -196,6 +197,13 @@ NULL #' \item \code{array_position}: a value to locate in the given array. #' \item \code{array_remove}: a value to remove in the given array. #' } +#' @param schema +#' \itemize{ +#' \item \code{from_json}: a structType object to use as the schema to use +#' when parsing the JSON string. Since Spark 2.3, the DDL-formatted string is +#' also supported for the schema. +#' \item \code{from_csv}: a DDL-formatted string +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. Additionally \code{to_json} supports the "pretty" @@ -2165,8 +2173,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' #' @rdname column_collection_functions -#' @param schema a structType object to use as the schema to use when parsing the JSON string. -#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method #' @examples @@ -2204,6 +2210,36 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") }) #' @details +#' \code{from_csv}: Parses a column containing a CSV string into a Column of \code{structType} +#' with the specified \code{schema}. +#' If the string is unparseable, the Column will contain the value NA. +#' +#' @rdname column_collection_functions +#' @aliases from_csv from_csv,Column,character-method +#' @examples +#' +#' \dontrun{ +#' df <- sql("SELECT 'Amsterdam,2018' as csv") +#' schema <- "city STRING, year INT" +#' head(select(df, from_csv(df$csv, schema)))} +#' @note from_csv since 3.0.0 +setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"), + function(x, schema, ...) { + if (class(schema) == "Column") { + jschema <- schema@jc + } else if (is.character(schema)) { + jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema) + } else { + stop("schema argument should be a column or character") + } + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "from_csv", + x@jc, jschema, options) + column(jc) + }) + +#' @details #' \code{from_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT #' TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a #' timestamp in UTC, and renders that timestamp as a timestamp in the given time zone. 
http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/R/pkg/R/generics.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 697d124..d501f73 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -984,6 +984,10 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("from_csv", function(x, schema, ...) { standardGeneric("from_csv") }) + #' @rdname column_datetime_functions #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/R/pkg/tests/fulltests/test_sparkSQL.R ---------------------------------------------------------------------- diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5cc75aa..5ad5d78 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,6 +1647,13 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + # Test from_csv() + df <- as.DataFrame(list(list("col" = "1"))) + c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + # Test to_json(), from_json() df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") j <- collect(select(df, alias(to_json(df$people), "json"))) http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5425d31..32d7f02 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,9 +25,12 @@ import warnings if sys.version < "3": from itertools import imap as map +if sys.version >= '3': + basestring = str + from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 @@ -2678,6 +2681,38 @@ def sequence(start, stop, step=None): _to_java_column(start), _to_java_column(stop), _to_java_column(step))) +@ignore_unicode_prefix +@since(3.0) +def from_csv(col, schema, options={}): + """ + Parses a column containing a CSV string to a row with the specified schema. + Returns `null`, in the case of an unparseable string. + + :param col: string column in CSV format + :param schema: a string with schema in DDL format to use when parsing the CSV column. + :param options: options to control parsing. 
accepts the same options as the CSV datasource + + >>> data = [(1, '1')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() + [Row(csv=Row(a=1))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() + [Row(csv=Row(a=1))] + """ + + sc = SparkContext._active_spark_context + if isinstance(schema, basestring): + schema = _create_column_from_literal(schema) + elif isinstance(schema, Column): + schema = _to_java_column(schema) + else: + raise TypeError("schema argument should be a column or string") + + jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/pom.xml ---------------------------------------------------------------------- diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 2e7df4f..16ecebf 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -103,6 +103,12 @@ <groupId>commons-codec</groupId> <artifactId>commons-codec</artifactId> </dependency> + <dependency> + <groupId>com.univocity</groupId> + <artifactId>univocity-parsers</artifactId> + <version>2.7.3</version> + <type>jar</type> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebf..38f5c02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -520,7 +520,10 @@ object FunctionRegistry { castAlias("date", DateType), castAlias("timestamp", TimestampType), castAlias("binary", BinaryType), - castAlias("string", StringType) + castAlias("string", StringType), + + // csv + expression[CsvToStructs]("from_csv") ) val builtin: SimpleFunctionRegistry = { http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala new file mode 100644 index 0000000..bbe2783 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +object CSVExprUtils { + /** + * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). + * This is currently being used in CSV reading path and CSV schema inference. + */ + def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + iter.filter { line => + line.trim.nonEmpty && !line.startsWith(options.comment.toString) + } + } + + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { + val commentPrefix = options.comment.toString + iter.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + iter.dropWhile(_.trim.isEmpty) + } + } + + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + (str: Seq[Char]) match { + case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string") + case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." + + " It has special meaning as beginning of an escape sequence." + + " To get the backslash character, pass a string with two backslashes as the delimiter.") + case Seq(c) => c + case Seq('\\', 't') => '\t' + case Seq('\\', 'r') => '\r' + case Seq('\\', 'b') => '\b' + case Seq('\\', 'f') => '\f' + // In case user changes quote char and uses \" as delimiter in options + case Seq('\\', '\"') => '\"' + case Seq('\\', '\'') => '\'' + case Seq('\\', '\\') => '\\' + case _ if str == """\u0000""" => '\u0000' + case Seq('\\', _) => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + case _ => + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala new file mode 100644 index 0000000..c39f77e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema provided (or inferred) schema to which CSV must conform. + * @param options parsed CSV options. + * @param source name of CSV source that are currently checked. It is used in error messages. + * @param isStartOfFile indicates if the currently processing partition is the start of the file. + * if unknown or not applicable (for instance when the input is a dataset), + * can be omitted. + */ +class CSVHeaderChecker( + schema: StructType, + options: CSVOptions, + source: String, + isStartOfFile: Boolean = false) extends Logging { + + // Indicates if it is set to `false`, comparison of column names and schema field + // names is not case sensitive. + private val caseSensitive = SQLConf.get.caseSensitiveAnalysis + + // Indicates if it is `true`, column names are ignored otherwise the CSV column + // names are checked for conformance to the schema. In the case if + // the column name don't conform to the schema, an exception is thrown. + private val enforceSchema = options.enforceSchema + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param columnNames names of CSV columns that must be checked against to the schema. + */ + private def checkHeaderColumnNames(columnNames: Array[String]): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + // scalastyle:off caselocale + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + // scalastyle:on caselocale + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |$source""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |$source""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + // This is currently only used to parse CSV from Dataset[String]. 
+ def checkHeaderColumnNames(line: String): Unit = { + if (options.headerFlag) { + val parser = new CsvParser(options.asParserSettings) + checkHeaderColumnNames(parser.parseLine(line)) + } + } + + // This is currently only used to parse CSV with multiLine mode. + private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { + assert(options.multiLine, "This method should be executed with multiLine.") + if (options.headerFlag) { + val firstRecord = tokenizer.parseNext() + checkHeaderColumnNames(firstRecord) + } + } + + // This is currently only used to parse CSV with non-multiLine mode. + private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { + assert(!options.multiLine, "This method should not be executed with multiline.") + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + if (options.headerFlag && isStartOfFile) { + CSVExprUtils.extractHeader(lines, options).foreach { header => + checkHeaderColumnNames(tokenizer.parseLine(header)) + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala new file mode 100644 index 0000000..3e25d82 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.csv + +import java.nio.charset.StandardCharsets +import java.util.{Locale, TimeZone} + +import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util._ + +class CSVOptions( + @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends Logging with Serializable { + + def this( + parameters: Map[String, String], + columnPruning: Boolean, + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + columnPruning, + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + private def getChar(paramName: String, default: Char): Char = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(null) => default + case Some(value) if value.length == 0 => '\u0000' + case Some(value) if value.length == 1 => value.charAt(0) + case _ => throw new RuntimeException(s"$paramName cannot be more than one character") + } + } + + private def getInt(paramName: String, default: Int): Int = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(null) => default + case Some(value) => try { + value.toInt + } catch { + case e: NumberFormatException => + throw new RuntimeException(s"$paramName should be an integer. Found $value") + } + } + } + + private def getBool(paramName: String, default: Boolean = false): Boolean = { + val param = parameters.getOrElse(paramName, default.toString) + if (param == null) { + default + } else if (param.toLowerCase(Locale.ROOT) == "true") { + true + } else if (param.toLowerCase(Locale.ROOT) == "false") { + false + } else { + throw new Exception(s"$paramName flag can be true or false") + } + } + + val delimiter = CSVExprUtils.toChar( + parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) + val charset = parameters.getOrElse("encoding", + parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) + + val quote = getChar("quote", '\"') + val escape = getChar("escape", '\\') + val charToEscapeQuoteEscaping = parameters.get("charToEscapeQuoteEscaping") match { + case None => None + case Some(null) => None + case Some(value) if value.length == 0 => None + case Some(value) if value.length == 1 => Some(value.charAt(0)) + case _ => + throw new RuntimeException("charToEscapeQuoteEscaping cannot be more than one character") + } + val comment = getChar("comment", '\u0000') + + val headerFlag = getBool("header") + val inferSchemaFlag = getBool("inferSchema") + val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) + val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) + + // For write, both options were `true` by default. We leave it as `true` for + // backwards compatibility. 
+ val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) + + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + + val nullValue = parameters.getOrElse("nullValue", "") + + val nanValue = parameters.getOrElse("nanValue", "NaN") + + val positiveInf = parameters.getOrElse("positiveInf", "Inf") + val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + + + val compressionCodec: Option[String] = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } + + val timeZone: TimeZone = DateTimeUtils.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. + val dateFormat: FastDateFormat = + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + + val timestampFormat: FastDateFormat = + FastDateFormat.getInstance( + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + + val maxColumns = getInt("maxColumns", 20480) + + val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) + + val escapeQuotes = getBool("escapeQuotes", true) + + val quoteAll = getBool("quoteAll", false) + + val inputBufferSize = 128 + + val isCommentSet = this.comment != '\u0000' + + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. + */ + val enforceSchema = getBool("enforceSchema", default = true) + + + /** + * String representation of an empty value in read and in write. + */ + val emptyValue = parameters.get("emptyValue") + /** + * The string is returned when CSV reader doesn't have any characters for input value, + * or an empty quoted string `""`. Default value is empty string. + */ + val emptyValueInRead = emptyValue.getOrElse("") + /** + * The value is used instead of an empty string in write. 
Default value is `""` + */ + val emptyValueInWrite = emptyValue.getOrElse("\"\"") + + def asWriterSettings: CsvWriterSettings = { + val writerSettings = new CsvWriterSettings() + val format = writerSettings.getFormat + format.setDelimiter(delimiter) + format.setQuote(quote) + format.setQuoteEscape(escape) + charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) + format.setComment(comment) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) + writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) + writerSettings.setNullValue(nullValue) + writerSettings.setEmptyValue(emptyValueInWrite) + writerSettings.setSkipEmptyLines(true) + writerSettings.setQuoteAllFields(quoteAll) + writerSettings.setQuoteEscapingEnabled(escapeQuotes) + writerSettings + } + + def asParserSettings: CsvParserSettings = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(delimiter) + format.setQuote(quote) + format.setQuoteEscape(escape) + charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) + format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(inputBufferSize) + settings.setMaxColumns(maxColumns) + settings.setNullValue(nullValue) + settings.setEmptyValue(emptyValueInRead) + settings.setMaxCharsPerColumn(maxCharsPerColumn) + settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) + settings + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala new file mode 100644 index 0000000..46ed58e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.csv + +import java.io.InputStream +import java.math.BigDecimal + +import scala.util.Try +import scala.util.control.NonFatal + +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ +class UnivocityParser( + dataSchema: StructType, + requiredSchema: StructType, + val options: CSVOptions) extends Logging { + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") + + def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) + + // A `ValueConverter` is responsible for converting the given value to a desired type. + private type ValueConverter = String => Any + + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + + val tokenizer = { + val parserSetting = options.asParserSettings + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + + private val row = new GenericInternalRow(requiredSchema.length) + + // Retrieve the raw record string. + private def getCurrentInput: UTF8String = { + UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) + } + + // This parser first picks some tokens from the input tokens, according to the required schema, + // then parse these tokens and put the values in a row, with the order specified by the required + // schema. + // + // For example, let's say there is CSV data as below: + // + // a,b,c + // 1,2,A + // + // So the CSV data schema is: ["a", "b", "c"] + // And let's say the required schema is: ["c", "b"] + // + // with the input tokens, + // + // input tokens - [1, 2, "A"] + // + // Each input token is placed in each output row's position by mapping these. In this case, + // + // output row - ["A", 2] + private val valueConverters: Array[ValueConverter] = { + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + } + + /** + * Create a converter which converts the string value to a value according to a desired type. + * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). 
+ * + * For other nullable types, returns null if it is null or equals to the value specified + * in `nullValue` option. + */ + def makeConverter( + name: String, + dataType: DataType, + nullable: Boolean = true, + options: CSVOptions): ValueConverter = dataType match { + case _: ByteType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toByte) + + case _: ShortType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toShort) + + case _: IntegerType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toInt) + + case _: LongType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toLong) + + case _: FloatType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case datum => datum.toFloat + } + + case _: DoubleType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case datum => datum.toDouble + } + + case _: BooleanType => (d: String) => + nullSafeDatum(d, name, nullable, options)(_.toBoolean) + + case dt: DecimalType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + } + + case _: TimestampType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + } + + case _: DateType => (d: String) => + nullSafeDatum(d, name, nullable, options) { datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + } + + case _: StringType => (d: String) => + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) + + case udt: UserDefinedType[_] => (datum: String) => + makeConverter(name, udt.sqlType, nullable, options) + + // We don't actually hit this exception though, we keep it for understandability + case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") + } + + private def nullSafeDatum( + datum: String, + name: String, + nullable: Boolean, + options: CSVOptions)(converter: ValueConverter): Any = { + if (datum == options.nullValue || datum == null) { + if (!nullable) { + throw new RuntimeException(s"null value found but field $name is not nullable.") + } + null + } else { + converter.apply(datum) + } + } + + /** + * Parses a single CSV string and turns it into either one resulting row or no row (if the + * the record is malformed). 
+ */ + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } + + private def convert(tokens: Array[String]): InternalRow = { + if (tokens == null) { + throw BadRecordException( + () => getCurrentInput, + () => None, + new RuntimeException("Malformed CSV record")) + } else if (tokens.length != parsedSchema.length) { + // If the number of tokens doesn't match the schema, we should treat it as a malformed record. + // However, we still have chance to parse some of the tokens, by adding extra null tokens in + // the tail if the number is smaller, or by dropping extra tokens if the number is larger. + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) + } else { + tokens.take(parsedSchema.length) + } + def getPartialResult(): Option[InternalRow] = { + try { + Some(convert(checkedTokens)) + } catch { + case _: BadRecordException => None + } + } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. + throw BadRecordException( + () => getCurrentInput, + () => getPartialResult(), + new RuntimeException("Malformed CSV record")) + } else { + try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. + var i = 0 + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) + i += 1 + } + row + } catch { + case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. + throw BadRecordException(() => getCurrentInput, () => None, e) + } + } + } +} + +private[sql] object UnivocityParser { + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of tokens. + */ + def tokenizeStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser): Iterator[Array[String]] = { + val handleHeader: () => Unit = + () => if (shouldDropHeader) tokenizer.parseNext + + convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens) + } + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of rows. + */ + def parseStream( + inputStream: InputStream, + parser: UnivocityParser, + headerChecker: CSVHeaderChecker, + schema: StructType): Iterator[InternalRow] = { + val tokenizer = parser.tokenizer + val safeParser = new FailureSafeParser[Array[String]]( + input => Seq(parser.convert(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) + + val handleHeader: () => Unit = + () => headerChecker.checkHeaderColumnNames(tokenizer) + + convertStream(inputStream, tokenizer, handleHeader) { tokens => + safeParser.parse(tokens) + }.flatten + } + + private def convertStream[T]( + inputStream: InputStream, + tokenizer: CsvParser, + handleHeader: () => Unit)( + convert: Array[String] => T) = new Iterator[T] { + tokenizer.beginParsing(inputStream) + + // We can handle header here since here the stream is open. 
+ handleHeader() + + private var nextRecord = tokenizer.parseNext() + + override def hasNext: Boolean = nextRecord != null + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + val curRecord = convert(nextRecord) + nextRecord = tokenizer.parseNext() + curRecord + } + } + + /** + * Parses an iterator that contains CSV strings and turns it into an iterator of rows. + */ + def parseIterator( + lines: Iterator[String], + parser: UnivocityParser, + headerChecker: CSVHeaderChecker, + schema: StructType): Iterator[InternalRow] = { + headerChecker.checkHeaderColumnNames(lines, parser.tokenizer) + + val options = parser.options + + val filteredLines: Iterator[String] = CSVExprUtils.filterCommentAndEmpty(lines, options) + + val safeParser = new FailureSafeParser[String]( + input => Seq(parser.parse(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) + filteredLines.flatMap(safeParser.parse) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala new file mode 100644 index 0000000..e570889 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.{MapType, StringType, StructType} + +object ExprUtils { + + def evalSchemaExpr(exp: Expression): StructType = exp match { + case Literal(s, StringType) => StructType.fromDDL(s.toString) + case e => throw new AnalysisException( + s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + } + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala new file mode 100644 index 0000000..a63b624 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a CSV input string to a [[StructType]] with the specified schema. 
+ */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`.", + examples = """ + Examples: + > SELECT _FUNC_('1, 0.8', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT _FUNC_('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) + {"time":2015-08-26 00:00:00.0} + """, + since = "3.0.0") +// scalastyle:on line.size.limit +case class CsvToStructs( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression + with TimeZoneAwareExpression + with CodegenFallback + with ExpectsInputTypes + with NullIntolerant { + + override def nullable: Boolean = child.nullable + + // The CSV input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. + val nullableSchema: StructType = schema.asNullable + + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression, options: Map[String, String]) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = options, + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = ExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + + // This converts parsed rows to the desired output by the given schema. + @transient + lazy val converter = (rows: Iterator[InternalRow]) => { + if (rows.hasNext) { + val result = rows.next() + // CSV's parser produces one record only. + assert(!rows.hasNext) + result + } else { + throw new IllegalArgumentException("Expected one row from CSV parser.") + } + } + + @transient lazy val parser = { + val parsedOptions = new CSVOptions(options, columnPruning = true, timeZoneId.get) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. 
" + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") + } + val actualSchema = + StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) + new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + mode, + nullableSchema, + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) + } + + override def dataType: DataType = nullableSchema + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { + copy(timeZoneId = Option(timeZoneId)) + } + + override def nullSafeEval(input: Any): Any = { + val csv = input.asInstanceOf[UTF8String].toString + converter(parser.parse(csv)) + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f5297dd..9f28483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -539,7 +539,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.evalSchemaExpr(schema), - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -650,7 +650,7 @@ case class StructsToJson( def this(child: Expression) = this(Map.empty, child, None) def this(child: Expression, options: Expression) = this( - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -754,7 +754,7 @@ case class SchemaOfJson( def this(child: Expression, options: Expression) = this( child = child, - options = JsonExprUtils.convertToMapData(options)) + options = ExprUtils.convertToMapData(options)) @transient private lazy val jsonOptions = new JSONOptions(options, "UTC") @@ -777,7 +777,6 @@ case class SchemaOfJson( } object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) case e @ SchemaOfJson(_: Literal, _) => @@ -787,18 +786,4 @@ object JsonExprUtils { "Schema should be specified in DDL format as a string literal" + s" or output of the schema_of_json function instead of ${e.sql}") } - - def convertToMapData(exp: Expression): Map[String, String] = exp match { - case m: CreateMap - if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => - val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] - ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => - key.toString -> value.toString - } - case m: CreateMap => - throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") - case _ => - throw new AnalysisException("Must use a map() function for options") - } } http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala new file mode 100644 index 0000000..fecfff5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class FailureSafeParser[IN]( + rawParser: IN => Seq[InternalRow], + mode: ParseMode, + schema: StructType, + columnNameOfCorruptRecord: String, + isMultiLine: Boolean) { + + private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) + private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) + private val resultRow = new GenericInternalRow(schema.length) + private val nullResult = new GenericInternalRow(schema.length) + + // This function takes 2 parameters: an optional partial result, and the bad record. If the given + // schema doesn't contain a field for corrupted record, we just return the partial result or a + // row with all fields null. If the given schema contains a field for corrupted record, we will + // set the bad record to this field, and set other fields according to the partial result or null. + private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { + if (corruptFieldIndex.isDefined) { + (row, badRecord) => { + var i = 0 + while (i < actualSchema.length) { + val from = actualSchema(i) + resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull + i += 1 + } + resultRow(corruptFieldIndex.get) = badRecord() + resultRow + } + } else { + (row, _) => row.getOrElse(nullResult) + } + } + + private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty + + def parse(input: IN): Iterator[InternalRow] = { + try { + if (skipParsing) { + Iterator.single(InternalRow.empty) + } else { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } + } catch { + case e: BadRecordException => mode match { + case PermissiveMode => + Iterator(toResultRow(e.partialResult(), e.record)) + case DropMalformedMode => + Iterator.empty + case FailFastMode => + throw new SparkException("Malformed records are detected in record parsing. 
" + + s"Parse Mode: ${FailFastMode.name}.", e.cause) + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala new file mode 100644 index 0000000..838ac42 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import org.apache.spark.SparkFunSuite + +class CSVExprUtilsSuite extends SparkFunSuite { + test("Can parse escaped characters") { + assert(CSVExprUtils.toChar("""\t""") === '\t') + assert(CSVExprUtils.toChar("""\r""") === '\r') + assert(CSVExprUtils.toChar("""\b""") === '\b') + assert(CSVExprUtils.toChar("""\f""") === '\f') + assert(CSVExprUtils.toChar("""\"""") === '\"') + assert(CSVExprUtils.toChar("""\'""") === '\'') + assert(CSVExprUtils.toChar("""\u0000""") === '\u0000') + assert(CSVExprUtils.toChar("""\\""") === '\\') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVExprUtils.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVExprUtils.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + + test("string with one backward slash is prohibited") { + val exception = intercept[IllegalArgumentException]{ + CSVExprUtils.toChar("""\""") + } + assert(exception.getMessage.contains("Single backslash is prohibited")) + } + + test("output proper error message for empty string") { + val exception = intercept[IllegalArgumentException]{ + CSVExprUtils.toChar("") + } + assert(exception.getMessage.contains("Delimiter cannot be empty string")) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala new file mode 100644 index 0000000..65987af --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Calendar + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { + val badCsv = "\u0000\u0000\u0000A\u0001AAA" + + val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID) + + test("from_csv") { + val csvData = "1" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), + InternalRow(1) + ) + } + + test("from_csv - invalid data") { + val csvData = "---" + val schema = StructType(StructField("a", DoubleType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(csvData), gmtId), + InternalRow(null)) + + // Default mode is Permissive + checkEvaluation(CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), InternalRow(null)) + } + + test("from_csv null input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + null + ) + } + + test("from_csv bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(badCsv), gmtId), + InternalRow(null)) + } + + test("from_csv with timestamp") { + val schema = StructType(StructField("t", TimestampType) :: Nil) + + val csvData1 = "2016-01-01T00:00:00.123Z" + var c = Calendar.getInstance(DateTimeUtils.TimeZoneGMT) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + // The result doesn't change because the CSV string includes timezone string ("Z" here), + // which means the string represents the timestamp string in the timezone regardless of + // the timeZoneId parameter. 
+ checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), Option("PST")), + InternalRow(c.getTimeInMillis * 1000L) + ) + + val csvData2 = "2016-01-01T00:00:00" + for (tz <- DateTimeTestUtils.outstandingTimezones) { + c = Calendar.getInstance(tz) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + Literal(csvData2), + Option(tz.getID)), + InternalRow(c.getTimeInMillis * 1000L) + ) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), + Literal(csvData2), + gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + } + } + + test("from_csv empty input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + InternalRow(null) + ) + } + + test("forcing schema nullability") { + val input = """1,,"foo"""" + val csvSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = CsvToStructs(csvSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = csvSchema.asNullable + assert(schemaToCompare == schema) + } + + + test("from_csv missing columns") { + val schema = new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create("1"), gmtId), + InternalRow(1, null) + ) + } + + test("unsupported mode") { + val csvData = "---" + val schema = StructType(StructField("a", DoubleType) :: Nil) + val exception = intercept[TestFailedException] { + checkEvaluation( + CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), gmtId), + InternalRow(null)) + }.getCause + assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3af70b5..4f6d8b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,16 +22,17 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper -import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ 
import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala deleted file mode 100644 index 90e8166..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -class FailureSafeParser[IN]( - rawParser: IN => Seq[InternalRow], - mode: ParseMode, - schema: StructType, - columnNameOfCorruptRecord: String, - isMultiLine: Boolean) { - - private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) - private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) - private val resultRow = new GenericInternalRow(schema.length) - private val nullResult = new GenericInternalRow(schema.length) - - // This function takes 2 parameters: an optional partial result, and the bad record. If the given - // schema doesn't contain a field for corrupted record, we just return the partial result or a - // row with all fields null. If the given schema contains a field for corrupted record, we will - // set the bad record to this field, and set other fields according to the partial result or null. 
- private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { - if (corruptFieldIndex.isDefined) { - (row, badRecord) => { - var i = 0 - while (i < actualSchema.length) { - val from = actualSchema(i) - resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull - i += 1 - } - resultRow(corruptFieldIndex.get) = badRecord() - resultRow - } - } else { - (row, _) => row.getOrElse(nullResult) - } - } - - private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty - - def parse(input: IN): Iterator[InternalRow] = { - try { - if (skipParsing) { - Iterator.single(InternalRow.empty) - } else { - rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) - } - } catch { - case e: BadRecordException => mode match { - case PermissiveMode => - Iterator(toResultRow(e.partialResult(), e.record)) - case DropMalformedMode => - Iterator.empty - case FailFastMode => - throw new SparkException("Malformed records are detected in record parsing. " + - s"Parse Mode: ${FailFastMode.name}.", e.cause) - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 0b5a719..9e7b45d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 3de1c2d..954a5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala deleted file mode 100644 index 558ee91..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import com.univocity.parsers.csv.CsvParser - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType - -/** - * Checks that column names in a CSV header and field names in the schema are the same - * by taking into account case sensitivity. - * - * @param schema provided (or inferred) schema to which CSV must conform. - * @param options parsed CSV options. - * @param source name of CSV source that are currently checked. It is used in error messages. - * @param isStartOfFile indicates if the currently processing partition is the start of the file. - * if unknown or not applicable (for instance when the input is a dataset), - * can be omitted. - */ -class CSVHeaderChecker( - schema: StructType, - options: CSVOptions, - source: String, - isStartOfFile: Boolean = false) extends Logging { - - // Indicates if it is set to `false`, comparison of column names and schema field - // names is not case sensitive. - private val caseSensitive = SQLConf.get.caseSensitiveAnalysis - - // Indicates if it is `true`, column names are ignored otherwise the CSV column - // names are checked for conformance to the schema. In the case if - // the column name don't conform to the schema, an exception is thrown. - private val enforceSchema = options.enforceSchema - - /** - * Checks that column names in a CSV header and field names in the schema are the same - * by taking into account case sensitivity. - * - * @param columnNames names of CSV columns that must be checked against to the schema. - */ - private def checkHeaderColumnNames(columnNames: Array[String]): Unit = { - if (columnNames != null) { - val fieldNames = schema.map(_.name).toIndexedSeq - val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) - var errorMessage: Option[String] = None - - if (headerLen == schemaSize) { - var i = 0 - while (errorMessage.isEmpty && i < headerLen) { - var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) - if (!caseSensitive) { - // scalastyle:off caselocale - nameInSchema = nameInSchema.toLowerCase - nameInHeader = nameInHeader.toLowerCase - // scalastyle:on caselocale - } - if (nameInHeader != nameInSchema) { - errorMessage = Some( - s"""|CSV header does not conform to the schema. 
- | Header: ${columnNames.mkString(", ")} - | Schema: ${fieldNames.mkString(", ")} - |Expected: ${fieldNames(i)} but found: ${columnNames(i)} - |$source""".stripMargin) - } - i += 1 - } - } else { - errorMessage = Some( - s"""|Number of column in CSV header is not equal to number of fields in the schema: - | Header length: $headerLen, schema size: $schemaSize - |$source""".stripMargin) - } - - errorMessage.foreach { msg => - if (enforceSchema) { - logWarning(msg) - } else { - throw new IllegalArgumentException(msg) - } - } - } - } - - // This is currently only used to parse CSV from Dataset[String]. - def checkHeaderColumnNames(line: String): Unit = { - if (options.headerFlag) { - val parser = new CsvParser(options.asParserSettings) - checkHeaderColumnNames(parser.parseLine(line)) - } - } - - // This is currently only used to parse CSV with multiLine mode. - private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { - assert(options.multiLine, "This method should be executed with multiLine.") - if (options.headerFlag) { - val firstRecord = tokenizer.parseNext() - checkHeaderColumnNames(firstRecord) - } - } - - // This is currently only used to parse CSV with non-multiLine mode. - private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { - assert(!options.multiLine, "This method should not be executed with multiline.") - // Checking that column names in the header are matched to field names of the schema. - // The header will be removed from lines. - // Note: if there are only comments in the first block, the header would probably - // be not extracted. - if (options.headerFlag && isStartOfFile) { - CSVUtils.extractHeader(lines, options).foreach { header => - checkHeaderColumnNames(tokenizer.parseLine(header)) - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e9af9460/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3596ff1..4326a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -23,6 +23,7 @@ import scala.util.control.Exception._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org