Repository: spark Updated Branches: refs/heads/master c5ef477d2 -> c9667aff4
[SPARK-25672][SQL] schema_of_csv() - schema inference from an example ## What changes were proposed in this pull request? In the PR, I propose adding a new function, *schema_of_csv()*, which infers the schema of a CSV string literal. The result of the function is a string containing a schema in DDL format. For example: ```sql select schema_of_csv('1|abc', map('delimiter', '|')) ``` ``` struct<_c0:int,_c1:string> ``` ## How was this patch tested? Added new tests to `CsvFunctionsSuite`, `CsvExpressionsSuite` and SQL tests to `csv-functions.sql`. Closes #22666 from MaxGekk/schema_of_csv-function. Lead-authored-by: hyukjinkwon <gurwls...@apache.org> Co-authored-by: Maxim Gekk <maxim.g...@databricks.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9667aff Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9667aff Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9667aff Branch: refs/heads/master Commit: c9667aff4f4888b650fad2ed41698025b1e84166 Parents: c5ef477 Author: hyukjinkwon <gurwls...@apache.org> Authored: Thu Nov 1 09:14:16 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Thu Nov 1 09:14:16 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 41 +++- .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../spark/sql/catalyst/csv/CSVInferSchema.scala | 220 +++++++++++++++++++ .../sql/catalyst/expressions/ExprUtils.scala | 33 ++- .../catalyst/expressions/csvExpressions.scala | 54 +++++ .../catalyst/expressions/jsonExpressions.scala | 16 +- .../sql/catalyst/csv/CSVInferSchemaSuite.scala | 142 ++++++++++++ .../sql/catalyst/csv/UnivocityParserSuite.scala | 199 +++++++++++++++++ .../expressions/CsvExpressionsSuite.scala | 10 + .../datasources/csv/CSVDataSource.scala | 2 +- .../datasources/csv/CSVInferSchema.scala | 214 ------------------ 
.../scala/org/apache/spark/sql/functions.scala | 35 +++ .../sql-tests/inputs/csv-functions.sql | 8 + .../sql-tests/results/csv-functions.sql.out | 54 ++++- .../apache/spark/sql/CsvFunctionsSuite.scala | 15 ++ .../datasources/csv/CSVInferSchemaSuite.scala | 143 ------------ .../datasources/csv/UnivocityParserSuite.scala | 200 ----------------- 17 files changed, 803 insertions(+), 586 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ca2a256..beb1a06 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}): return Column(jc) +@ignore_unicode_prefix +@since(3.0) +def schema_of_csv(csv, options={}): + """ + Parses a CSV string and infers its schema in DDL format. + + :param col: a CSV string or a string literal containing a CSV string. + :param options: options to control parsing. accepts the same options as the CSV datasource + + >>> df = spark.range(1) + >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + """ + if isinstance(csv, basestring): + col = _create_column_from_literal(csv) + elif isinstance(csv, Column): + col = _to_java_column(csv) + else: + raise TypeError("schema argument should be a column or string") + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_csv(col, options) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}): :param schema: a string with schema in DDL format to use when parsing the CSV column. 
:param options: options to control parsing. accepts the same options as the CSV datasource - >>> data = [(1, '1')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() - [Row(csv=Row(a=1))] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() - [Row(csv=Row(a=1))] + >>> data = [("1,2,3",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() + [Row(csv=Row(a=1, b=2, c=3))] + >>> value = data[0][0] + >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() + [Row(csv=Row(_c0=1, _c1=2, _c2=3))] """ sc = SparkContext._active_spark_context http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index af6166b..cf8fb7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,8 @@ object FunctionRegistry { castAlias("string", StringType), // csv - expression[CsvToStructs]("from_csv") + expression[CsvToStructs]("from_csv"), + expression[SchemaOfCsv]("schema_of_csv") ) val builtin: SimpleFunctionRegistry = { http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala new file mode 100644 index 0000000..799e999 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import java.math.BigDecimal + +import scala.util.control.Exception.allCatch + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +object CSVInferSchema { + + /** + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. 
Replace any null types with string type + */ + def infer( + tokenRDD: RDD[Array[String]], + header: Array[String], + options: CSVOptions): StructType = { + val fields = if (options.inferSchemaFlag) { + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) + + toStructFields(rootTypes, header, options) + } else { + // By default fields are assumed to be StringType + header.map(fieldName => StructField(fieldName, StringType, nullable = true)) + } + + StructType(fields) + } + + def toStructFields( + fieldTypes: Array[DataType], + header: Array[String], + options: CSVOptions): Array[StructField] = { + header.zip(fieldTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) + } + } + + def inferRowType(options: CSVOptions) + (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + var i = 0 + while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. + rowSoFar(i) = inferField(rowSoFar(i), next(i), options) + i+=1 + } + rowSoFar + } + + def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { + first.zipAll(second, NullType, NullType).map { case (a, b) => + compatibleType(a, b).getOrElse(NullType) + } + } + + /** + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. 
+ */ + def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { + if (field == null || field.isEmpty || field == options.nullValue) { + typeSoFar + } else { + typeSoFar match { + case NullType => tryParseInteger(field, options) + case IntegerType => tryParseInteger(field, options) + case LongType => tryParseLong(field, options) + case _: DecimalType => + // DecimalTypes have different precisions and scales, so we try to find the common type. + compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) + case DoubleType => tryParseDouble(field, options) + case TimestampType => tryParseTimestamp(field, options) + case BooleanType => tryParseBoolean(field, options) + case StringType => StringType + case other: DataType => + throw new UnsupportedOperationException(s"Unexpected data type $other") + } + } + } + + private def isInfOrNan(field: String, options: CSVOptions): Boolean = { + field == options.nanValue || field == options.negativeInf || field == options.positiveInf + } + + private def tryParseInteger(field: String, options: CSVOptions): DataType = { + if ((allCatch opt field.toInt).isDefined) { + IntegerType + } else { + tryParseLong(field, options) + } + } + + private def tryParseLong(field: String, options: CSVOptions): DataType = { + if ((allCatch opt field.toLong).isDefined) { + LongType + } else { + tryParseDecimal(field, options) + } + } + + private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + val decimalTry = allCatch opt { + // `BigDecimal` conversion can fail when the `field` is not a form of number. + val bigDecimal = new BigDecimal(field) + // Because many other formats do not support decimal, it reduces the cases for + // decimals by disallowing values having scale (eg. `1.1`). + if (bigDecimal.scale <= 0) { + // `DecimalType` conversion can fail when + // 1. The precision is bigger than 38. + // 2. scale is bigger than precision. 
+ DecimalType(bigDecimal.precision, bigDecimal.scale) + } else { + tryParseDouble(field, options) + } + } + decimalTry.getOrElse(tryParseDouble(field, options)) + } + + private def tryParseDouble(field: String, options: CSVOptions): DataType = { + if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { + DoubleType + } else { + tryParseTimestamp(field, options) + } + } + + private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { + // This case infers a custom `dataFormat` is set. + if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { + TimestampType + } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { + // We keep this for backwards compatibility. + TimestampType + } else { + tryParseBoolean(field, options) + } + } + + private def tryParseBoolean(field: String, options: CSVOptions): DataType = { + if ((allCatch opt field.toBoolean).isDefined) { + BooleanType + } else { + stringType() + } + } + + // Defining a function to return the StringType constant is necessary in order to work around + // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions; + // see issue #128 for more details. + private def stringType(): DataType = { + StringType + } + + /** + * Returns the common data type given two input data types so that the return type + * is compatible with both input data types. + */ + private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = { + TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2)) + } + + /** + * The following pattern matching represents additional type promotion rules that + * are CSV specific. + */ + private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = { + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) + + // These two cases below deal with when `IntegralType` is larger than `DecimalType`. 
+ case (t1: IntegralType, t2: DecimalType) => + compatibleType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + compatibleType(t1, DecimalType.forType(t2)) + + // Double support larger range than fixed decimal, DecimalType.Maximum should be enough + // in most case, also have better precision. + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => + Some(DoubleType) + + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + Some(DoubleType) + } else { + Some(DecimalType(range + scale, scale)) + } + case _ => None + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index e570889..040b56c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{MapType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String object ExprUtils { - def evalSchemaExpr(exp: Expression): StructType = exp match { - case Literal(s, StringType) => StructType.fromDDL(s.toString) + def evalSchemaExpr(exp: Expression): StructType = { + // Use `DataType.fromDDL` since the type string can be 
struct<...>. + val dataType = exp match { + case Literal(s, StringType) => + DataType.fromDDL(s.toString) + case e @ SchemaOfCsv(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_csv function instead of ${e.sql}") + } + + if (!dataType.isInstanceOf[StructType]) { + throw new AnalysisException( + s"Schema should be struct type but got ${dataType.sql}.") + } + dataType.asInstanceOf[StructType] + } + + def evalTypeExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( - s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 853b1ea..e70296f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import com.univocity.parsers.csv.CsvParser + import 
org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.csv._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util._ @@ -120,3 +123,54 @@ case class CsvToStructs( override def prettyName: String = "from_csv" } + +/** + * A function infers schema of CSV string. + */ +@ExpressionDescription( + usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of CSV string.", + examples = """ + Examples: + > SELECT _FUNC_('1,abc'); + struct<_c0:int,_c1:string> + """, + since = "3.0.0") +case class SchemaOfCsv( + child: Expression, + options: Map[String, String]) + extends UnaryExpression with CodegenFallback { + + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = ExprUtils.convertToMapData(options)) + + override def dataType: DataType = StringType + + override def nullable: Boolean = false + + @transient + private lazy val csv = child.eval().asInstanceOf[UTF8String] + + override def checkInputDataTypes(): TypeCheckResult = child match { + case Literal(s, StringType) if s != null => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"The input csv should be a string literal and not null; however, got ${child.sql}.") + } + + override def eval(v: InternalRow): Any = { + val parsedOptions = new CSVOptions(options, true, "UTC") + val parser = new CsvParser(parsedOptions.asParserSettings) + val row = parser.parseLine(csv.toString) + assert(row != null, "Parsed CSV record should not be null.") + + val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) + val st = 
StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + UTF8String.fromString(st.catalogString) + } + + override def prettyName: String = "schema_of_csv" +} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 77af590..eafcb61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -529,7 +529,7 @@ case class JsonToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = options, child = child, timeZoneId = None) @@ -538,7 +538,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -784,15 +784,3 @@ case class SchemaOfJson( override def prettyName: String = "schema_of_json" } - -object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { - case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] - DataType.fromDDL(ddlSchema.toString) - case e => throw new AnalysisException( - "Schema should be specified in DDL format as a string literal" + - s" or output of the schema_of_json function instead of ${e.sql}") - } 
-} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala new file mode 100644 index 0000000..651846d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.csv + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class CSVInferSchemaSuite extends SparkFunSuite { + + test("String fields types are inferred correctly from null types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(CSVInferSchema.inferField(NullType, "", options) == NullType) + assert(CSVInferSchema.inferField(NullType, null, options) == NullType) + assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) + assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType) + assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType) + assert(CSVInferSchema.inferField(NullType, "test", options) == StringType) + assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) + assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) + assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) + } + + test("String fields types are inferred correctly from other types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) + assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) + assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType) + 
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType) + assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) + assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) + assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) + } + + test("Timestamp field types are inferred correctly via custom data format") { + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") + assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") + assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) + } + + test("Timestamp field types are inferred correctly from other types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) + } + + test("Boolean fields types are inferred correctly from other types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) + assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) + } + + test("Type arrays are merged to highest common type") { + assert( + CSVInferSchema.mergeRowTypes(Array(StringType), + Array(DoubleType)).deep == Array(StringType).deep) + assert( + 
CSVInferSchema.mergeRowTypes(Array(IntegerType), + Array(LongType)).deep == Array(LongType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(DoubleType), + Array(LongType)).deep == Array(DoubleType).deep) + } + + test("Null fields are handled properly when a nullValue is specified") { + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") + assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) + assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) + assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) + + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") + assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) + assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) + assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) + assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) + assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) + } + + test("Merging Nulltypes should yield Nulltype.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } + + test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") + assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + } + + test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + + // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). + assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == + DecimalType(4, -9)) + + // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. 
+ val value = "12345678901234567890.01234567890123456789" + assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) + + // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType + assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) + assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) + == StringType) + } + + test("DoubleType should be inferred when user defined nan/inf are provided") { + val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", + "positiveInf" -> "inf"), false, "GMT") + assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) + assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) + assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala new file mode 100644 index 0000000..e4e7dc2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import java.math.BigDecimal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class UnivocityParserSuite extends SparkFunSuite { + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) + + private def assertNull(v: Any) = assert(v == null) + + test("Can parse decimal type values") { + val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") + val decimalValues = Seq(10.05, 1000.01, 158058049.001) + val decimalType = new DecimalType() + + stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => + val decimalValue = new BigDecimal(decimalVal.toString) + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === + Decimal(decimalValue, decimalType.precision, decimalType.scale)) + } + } + + test("Nullable types are handled") { + val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType) + + // Nullable field with nullValue option. + types.foreach { t => + // Tests that a custom nullValue. 
+ val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") + val converter = + parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) + assertNull(converter.apply("-")) + assertNull(converter.apply(null)) + + // Tests that the default nullValue is empty string. + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) + } + + // Not nullable field with nullValue option. + types.foreach { t => + // Casting a null to a non-nullable field should throw an exception. + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") + val converter = + parser.makeConverter("_1", t, nullable = false, options = options) + var message = intercept[RuntimeException] { + converter.apply("-") + }.getMessage + assert(message.contains("null value found but field _1 is not nullable.")) + message = intercept[RuntimeException] { + converter.apply(null) + }.getMessage + assert(message.contains("null value found but field _1 is not nullable.")) + } + + // If nullValue differs from the empty string, then an empty string should not be cast to + // null. 
+ Seq(true, false).foreach { b => + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") + val converter = + parser.makeConverter("_1", StringType, nullable = b, options = options) + assert(converter.apply("") == UTF8String.fromString("")) + } + } + + test("Throws exception for empty string with non null type") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val exception = intercept[RuntimeException]{ + parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") + } + assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) + } + + test("Types are cast correctly") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", LongType, options = options).apply("10") == 10) + assert(parser.makeConverter("_1", FloatType, options = options).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", DoubleType, options = options).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) + + val timestampsOptions = + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") + val customTimestamp = "31/01/2015 00:00" + val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime + val castedTimestamp = + parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions) + .apply(customTimestamp) + assert(castedTimestamp == expectedTime * 1000L) + + val customDate = "31/01/2015" + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") + val expectedDate = dateOptions.dateFormat.parse(customDate).getTime + val castedDate = + 
parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) + .apply(customTimestamp) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) + + val timestamp = "2015-01-01 00:00:00" + assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) == + DateTimeUtils.stringToTime(timestamp).getTime * 1000L) + assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") == + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) + } + + test("Throws exception for casting an invalid string to Float and Double Types") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val types = Seq(DoubleType, FloatType) + val input = Seq("10u000", "abc", "1 2/3") + types.foreach { dt => + input.foreach { v => + val message = intercept[NumberFormatException] { + parser.makeConverter("_1", dt, options = options).apply(v) + }.getMessage + assert(message.contains(v)) + } + } + } + + test("Float NaN values are parsed correctly") { + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") + val floatVal: Float = parser.makeConverter( + "_1", FloatType, nullable = true, options = options + ).apply("nn").asInstanceOf[Float] + + // Java implements the IEEE-754 floating point standard which guarantees that any comparison + // against NaN will return false (except != which returns true) + assert(floatVal != floatVal) + } + + test("Double NaN values are parsed correctly") { + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") + val doubleVal: Double = parser.makeConverter( + "_1", DoubleType, nullable = true, options = options + ).apply("-").asInstanceOf[Double] + + assert(doubleVal.isNaN) + } + + test("Float infinite values can be parsed") { + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + val floatVal1 = parser.makeConverter( + "_1", FloatType, nullable = true, options = negativeInfOptions + 
).apply("max").asInstanceOf[Float] + + assert(floatVal1 == Float.NegativeInfinity) + + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + val floatVal2 = parser.makeConverter( + "_1", FloatType, nullable = true, options = positiveInfOptions + ).apply("max").asInstanceOf[Float] + + assert(floatVal2 == Float.PositiveInfinity) + } + + test("Double infinite values can be parsed") { + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + val doubleVal1 = parser.makeConverter( + "_1", DoubleType, nullable = true, options = negativeInfOptions + ).apply("max").asInstanceOf[Double] + + assert(doubleVal1 == Double.NegativeInfinity) + + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + val doubleVal2 = parser.makeConverter( + "_1", DoubleType, nullable = true, options = positiveInfOptions + ).apply("max").asInstanceOf[Double] + + assert(doubleVal2 == Double.PositiveInfinity) + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 65987af..386e0d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P }.getCause assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) } + + test("infer schema of CSV strings") { + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + } + + 
test("infer schema of CSV strings by using options") { + checkEvaluation( + new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), + "struct<_c0:int,_c1:string>") + } } http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 9e7b45d..4808e8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala deleted file mode 100644 index 4326a18..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ /dev/null @@ -1,214 
+0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.math.BigDecimal - -import scala.util.control.Exception._ - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ - -private[csv] object CSVInferSchema { - - /** - * Similar to the JSON schema inference - * 1. Infer type of each row - * 2. Merge row types to find common type - * 3. 
Replace any null types with string type - */ - def infer( - tokenRDD: RDD[Array[String]], - header: Array[String], - options: CSVOptions): StructType = { - val fields = if (options.inferSchemaFlag) { - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val rootTypes: Array[DataType] = - tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) - - header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other - } - StructField(thisHeader, dType, nullable = true) - } - } else { - // By default fields are assumed to be StringType - header.map(fieldName => StructField(fieldName, StringType, nullable = true)) - } - - StructType(fields) - } - - private def inferRowType(options: CSVOptions) - (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { - var i = 0 - while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. - rowSoFar(i) = inferField(rowSoFar(i), next(i), options) - i+=1 - } - rowSoFar - } - - def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case (a, b) => - compatibleType(a, b).getOrElse(NullType) - } - } - - /** - * Infer type of string field. Given known type Double, and a string "1", there is no - * point checking if it is an Int, as the final type must be Double or higher. - */ - def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { - if (field == null || field.isEmpty || field == options.nullValue) { - typeSoFar - } else { - typeSoFar match { - case NullType => tryParseInteger(field, options) - case IntegerType => tryParseInteger(field, options) - case LongType => tryParseLong(field, options) - case _: DecimalType => - // DecimalTypes have different precisions and scales, so we try to find the common type. 
- compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) - case DoubleType => tryParseDouble(field, options) - case TimestampType => tryParseTimestamp(field, options) - case BooleanType => tryParseBoolean(field, options) - case StringType => StringType - case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type $other") - } - } - } - - private def isInfOrNan(field: String, options: CSVOptions): Boolean = { - field == options.nanValue || field == options.negativeInf || field == options.positiveInf - } - - private def tryParseInteger(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toInt).isDefined) { - IntegerType - } else { - tryParseLong(field, options) - } - } - - private def tryParseLong(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toLong).isDefined) { - LongType - } else { - tryParseDecimal(field, options) - } - } - - private def tryParseDecimal(field: String, options: CSVOptions): DataType = { - val decimalTry = allCatch opt { - // `BigDecimal` conversion can fail when the `field` is not a form of number. - val bigDecimal = new BigDecimal(field) - // Because many other formats do not support decimal, it reduces the cases for - // decimals by disallowing values having scale (eg. `1.1`). - if (bigDecimal.scale <= 0) { - // `DecimalType` conversion can fail when - // 1. The precision is bigger than 38. - // 2. scale is bigger than precision. 
- DecimalType(bigDecimal.precision, bigDecimal.scale) - } else { - tryParseDouble(field, options) - } - } - decimalTry.getOrElse(tryParseDouble(field, options)) - } - - private def tryParseDouble(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { - DoubleType - } else { - tryParseTimestamp(field, options) - } - } - - private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { - // This case infers a custom `dataFormat` is set. - if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { - TimestampType - } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwards compatibility. - TimestampType - } else { - tryParseBoolean(field, options) - } - } - - private def tryParseBoolean(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toBoolean).isDefined) { - BooleanType - } else { - stringType() - } - } - - // Defining a function to return the StringType constant is necessary in order to work around - // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions; - // see issue #128 for more details. - private def stringType(): DataType = { - StringType - } - - /** - * Returns the common data type given two input data types so that the return type - * is compatible with both input data types. - */ - private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = { - TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2)) - } - - /** - * The following pattern matching represents additional type promotion rules that - * are CSV specific. - */ - private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = { - case (StringType, t2) => Some(StringType) - case (t1, StringType) => Some(StringType) - - // These two cases below deal with when `IntegralType` is larger than `DecimalType`. 
- case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2) - case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2)) - - // Double support larger range than fixed decimal, DecimalType.Maximum should be enough - // in most case, also have better precision. - case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => - Some(DoubleType) - - case (t1: DecimalType, t2: DecimalType) => - val scale = math.max(t1.scale, t2.scale) - val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) - if (range + scale > 38) { - // DecimalType can't support precision > 38 - Some(DoubleType) - } else { - Some(DecimalType(range + scale, scale)) - } - case _ => None - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5348b65..f8c4d88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3870,6 +3870,41 @@ object functions { withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) } + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv)) + + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a string literal containing a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr)) + + /** + * Parses a CSV string and infers its schema in DDL format using options. + * + * @param csv a string literal containing a CSV string. 
+ * @param options options to control how the CSV is parsed. Accepts the same options as the + * CSV data source. See [[DataFrameReader#csv]]. + * @return a column with string literal containing schema in DDL format. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index d2214fd..5be6f80 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -7,3 +7,11 @@ select from_csv('1', 'a InvalidType'); select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_csv('1', 'a INT', map('mode', 1)); select from_csv(); +-- infer schema of csv literal +select from_csv('1,abc', schema_of_csv('1,abc')); +select schema_of_csv('1|abc', map('delimiter', '|')); +select schema_of_csv(null); +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a'); +SELECT schema_of_csv(csvField) FROM csvTable; +-- Clean up +DROP VIEW IF EXISTS csvTable; http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index f19f34a..677bbd9 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 13 -- !query 0 @@ -24,7 +24,7 @@ select from_csv('1', 1) struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7 -- !query 3 @@ -67,3 +67,53 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 + + +-- !query 7 +select from_csv('1,abc', schema_of_csv('1,abc')) +-- !query 7 schema +struct<from_csv(1,abc):struct<_c0:int,_c1:string>> +-- !query 7 output +{"_c0":1,"_c1":"abc"} + + +-- !query 8 +select schema_of_csv('1|abc', map('delimiter', '|')) +-- !query 8 schema +struct<schema_of_csv(1|abc):string> +-- !query 8 output +struct<_c0:int,_c1:string> + + +-- !query 9 +select schema_of_csv(null) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7 + + +-- !query 10 +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a') +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +SELECT schema_of_csv(csvField) FROM csvTable +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7 + + +-- !query 12 +DROP VIEW IF EXISTS csvTable +-- !query 12 schema +struct<> +-- !query 12 output + 
http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 38a2143..9395f05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -59,4 +59,19 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) } + + test("schema_of_csv - infers schemas") { + checkAnswer( + spark.range(1).select(schema_of_csv(lit("0.1,1"))), + Seq(Row("struct<_c0:double,_c1:int>"))) + checkAnswer( + spark.range(1).select(schema_of_csv("0.1,1")), + Seq(Row("struct<_c0:double,_c1:int>"))) + } + + test("schema_of_csv - infers schemas using options") { + val df = spark.range(1) + .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) + checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala deleted file mode 100644 index 6b64f2f..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.types._ - -class CSVInferSchemaSuite extends SparkFunSuite { - - test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(NullType, "", options) == NullType) - assert(CSVInferSchema.inferField(NullType, null, options) == NullType) - assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) - assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType) - assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "test", options) == StringType) - assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) - - val textValueOne = Long.MaxValue.toString + "0" - val decimalValueOne = new java.math.BigDecimal(textValueOne) - val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - 
assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) - } - - test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) - assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) - assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) - - val textValueOne = Long.MaxValue.toString + "0" - val decimalValueOne = new java.math.BigDecimal(textValueOne) - val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) - } - - test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) - } - - test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == 
StringType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) - } - - test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) - } - - test("Type arrays are merged to highest common type") { - assert( - CSVInferSchema.mergeRowTypes(Array(StringType), - Array(DoubleType)).deep == Array(StringType).deep) - assert( - CSVInferSchema.mergeRowTypes(Array(IntegerType), - Array(LongType)).deep == Array(LongType).deep) - assert( - CSVInferSchema.mergeRowTypes(Array(DoubleType), - Array(LongType)).deep == Array(DoubleType).deep) - } - - test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) - assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - - options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) - assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) - assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) - assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) - assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) - } - - test("Merging Nulltypes should yield Nulltype.") { - val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) - assert(mergedNullTypes.deep == Array(NullType).deep) - } - - test("SPARK-18433: 
Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - } - - test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - - // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). - assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == - DecimalType(4, -9)) - - // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. - val value = "12345678901234567890.01234567890123456789" - assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) - - // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType - assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) - assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) - == StringType) - } - - test("DoubleType should be inferred when user defined nan/inf are provided") { - val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala deleted file mode 
100644 index 6f23114..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import java.math.BigDecimal - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class UnivocityParserSuite extends SparkFunSuite { - private val parser = new UnivocityParser( - StructType(Seq.empty), - new CSVOptions(Map.empty[String, String], false, "GMT")) - - private def assertNull(v: Any) = assert(v == null) - - test("Can parse decimal type values") { - val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") - val decimalValues = Seq(10.05, 1000.01, 158058049.001) - val decimalType = new DecimalType() - - stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => - val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - 
assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === - Decimal(decimalValue, decimalType.precision, decimalType.scale)) - } - } - - test("Nullable types are handled") { - val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType) - - // Nullable field with nullValue option. - types.foreach { t => - // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) - assertNull(converter.apply("-")) - assertNull(converter.apply(null)) - - // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) - } - - // Not nullable field with nullValue option. - types.foreach { t => - // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = false, options = options) - var message = intercept[RuntimeException] { - converter.apply("-") - }.getMessage - assert(message.contains("null value found but field _1 is not nullable.")) - message = intercept[RuntimeException] { - converter.apply(null) - }.getMessage - assert(message.contains("null value found but field _1 is not nullable.")) - } - - // If nullValue is different with empty string, then, empty string should not be casted into - // null. 
- Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - val converter = - parser.makeConverter("_1", StringType, nullable = b, options = options) - assert(converter.apply("") == UTF8String.fromString("")) - } - } - - test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - val exception = intercept[RuntimeException]{ - parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") - } - assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) - } - - test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", LongType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", FloatType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", DoubleType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) - - val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") - val customTimestamp = "31/01/2015 00:00" - val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime - val castedTimestamp = - parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions) - .apply(customTimestamp) - assert(castedTimestamp == expectedTime * 1000L) - - val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") - val expectedDate = dateOptions.dateFormat.parse(customDate).getTime - val castedDate = - 
parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) - .apply(customTimestamp) - assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) - - val timestamp = "2015-01-01 00:00:00" - assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) == - DateTimeUtils.stringToTime(timestamp).getTime * 1000L) - assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") == - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) - } - - test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") - val types = Seq(DoubleType, FloatType) - val input = Seq("10u000", "abc", "1 2/3") - types.foreach { dt => - input.foreach { v => - val message = intercept[NumberFormatException] { - parser.makeConverter("_1", dt, options = options).apply(v) - }.getMessage - assert(message.contains(v)) - } - } - } - - test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") - val floatVal: Float = parser.makeConverter( - "_1", FloatType, nullable = true, options = options - ).apply("nn").asInstanceOf[Float] - - // Java implements the IEEE-754 floating point standard which guarantees that any comparison - // against NaN will return false (except != which returns true) - assert(floatVal != floatVal) - } - - test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") - val doubleVal: Double = parser.makeConverter( - "_1", DoubleType, nullable = true, options = options - ).apply("-").asInstanceOf[Double] - - assert(doubleVal.isNaN) - } - - test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") - val floatVal1 = parser.makeConverter( - "_1", FloatType, nullable = true, options = negativeInfOptions - 
).apply("max").asInstanceOf[Float] - - assert(floatVal1 == Float.NegativeInfinity) - - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") - val floatVal2 = parser.makeConverter( - "_1", FloatType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Float] - - assert(floatVal2 == Float.PositiveInfinity) - } - - test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") - val doubleVal1 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = negativeInfOptions - ).apply("max").asInstanceOf[Double] - - assert(doubleVal1 == Double.NegativeInfinity) - - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") - val doubleVal2 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Double] - - assert(doubleVal2 == Double.PositiveInfinity) - } - -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org