Repository: spark
Updated Branches:
  refs/heads/master 59741887e -> f97326bcd
[SPARK-25977][SQL] Parsing decimals from CSV using locale

## What changes were proposed in this pull request?

In the PR, I propose using the locale option to parse decimals from CSV input. After the changes, `UnivocityParser` converts the input string to `BigDecimal`, and then to Spark's `Decimal`, by using `java.text.DecimalFormat`.

## How was this patch tested?

Added tests for the `en-US`, `ko-KR`, `ru-RU` and `de-DE` locales.

Closes #22979 from MaxGekk/decimal-parsing-locale.

Lead-authored-by: Maxim Gekk <maxim.g...@databricks.com>
Co-authored-by: Maxim Gekk <max.g...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f97326bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f97326bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f97326bc

Branch: refs/heads/master
Commit: f97326bcdba532eabf25d4899b13709e9af2bfea
Parents: 5974188
Author: Maxim Gekk <maxim.g...@databricks.com>
Authored: Fri Nov 30 08:27:55 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Fri Nov 30 08:27:55 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/csv/CSVExprUtils.scala   |   4 +
 .../spark/sql/catalyst/csv/CSVInferSchema.scala |  72 ++++-----
 .../sql/catalyst/csv/UnivocityParser.scala      |   8 +-
 .../catalyst/expressions/csvExpressions.scala   |   5 +-
 .../sql/catalyst/csv/CSVInferSchemaSuite.scala  | 147 ++++++++++++-------
 .../sql/catalyst/csv/UnivocityParserSuite.scala |  22 ++-
 .../datasources/csv/CSVDataSource.scala         |   4 +-
 7 files changed, 168 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
index bbe2783..6c982a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.csv
 
+import java.math.BigDecimal
+import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
+import java.util.Locale
+
 object CSVExprUtils {
   /**
    * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
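For context, the option wired through by this patch is exercised from the DataFrame reader roughly as follows. This is an illustrative sketch only: the file path and data are made up, and the `locale`, `inferSchema` and `sep` settings mirror the ones used in the new tests.

    // Hypothetical usage sketch (not part of the patch).
    // Assumes a SparkSession `spark` and a made-up CSV file whose decimals are
    // written with the German decimal separator, e.g. "1000,001".
    val prices = spark.read
      .option("locale", "de-DE")        // interpret decimals using German number formatting
      .option("inferSchema", "true")    // let CSVInferSchema pick DecimalType/DoubleType
      .option("sep", "|")               // the default ',' separator would collide with the decimal comma
      .csv("/tmp/prices.csv")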
http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 799e999..94cb4b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.catalyst.csv -import java.math.BigDecimal - import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -object CSVInferSchema { +class CSVInferSchema(options: CSVOptions) extends Serializable { + + private val decimalParser = { + ExprUtils.getDecimalParser(options.locale) + } /** * Similar to the JSON schema inference @@ -36,14 +39,13 @@ object CSVInferSchema { */ def infer( tokenRDD: RDD[Array[String]], - header: Array[String], - options: CSVOptions): StructType = { + header: Array[String]): StructType = { val fields = if (options.inferSchemaFlag) { val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType, mergeRowTypes) - toStructFields(rootTypes, header, options) + toStructFields(rootTypes, header) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -54,8 +56,7 @@ object CSVInferSchema { def toStructFields( fieldTypes: Array[DataType], - header: Array[String], - options: CSVOptions): Array[StructField] = { + header: Array[String]): Array[StructField] = { header.zip(fieldTypes).map { case (thisHeader, rootType) => val dType = rootType match { case _: NullType => StringType @@ -65,11 +66,10 @@ object CSVInferSchema { } } - def inferRowType(options: CSVOptions) - (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. - rowSoFar(i) = inferField(rowSoFar(i), next(i), options) + rowSoFar(i) = inferField(rowSoFar(i), next(i)) i+=1 } rowSoFar @@ -85,20 +85,20 @@ object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. 
*/ - def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { + def inferField(typeSoFar: DataType, field: String): DataType = { if (field == null || field.isEmpty || field == options.nullValue) { typeSoFar } else { typeSoFar match { - case NullType => tryParseInteger(field, options) - case IntegerType => tryParseInteger(field, options) - case LongType => tryParseLong(field, options) + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) case _: DecimalType => // DecimalTypes have different precisions and scales, so we try to find the common type. - compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) - case DoubleType => tryParseDouble(field, options) - case TimestampType => tryParseTimestamp(field, options) - case BooleanType => tryParseBoolean(field, options) + compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => throw new UnsupportedOperationException(s"Unexpected data type $other") @@ -106,30 +106,30 @@ object CSVInferSchema { } } - private def isInfOrNan(field: String, options: CSVOptions): Boolean = { + private def isInfOrNan(field: String): Boolean = { field == options.nanValue || field == options.negativeInf || field == options.positiveInf } - private def tryParseInteger(field: String, options: CSVOptions): DataType = { + private def tryParseInteger(field: String): DataType = { if ((allCatch opt field.toInt).isDefined) { IntegerType } else { - tryParseLong(field, options) + tryParseLong(field) } } - private def tryParseLong(field: String, options: CSVOptions): DataType = { + private def tryParseLong(field: String): DataType = { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDecimal(field, options) + tryParseDecimal(field) } } - private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + private def tryParseDecimal(field: String): DataType = { val decimalTry = allCatch opt { - // `BigDecimal` conversion can fail when the `field` is not a form of number. - val bigDecimal = new BigDecimal(field) + // The conversion can fail when the `field` is not a form of number. + val bigDecimal = decimalParser(field) // Because many other formats do not support decimal, it reduces the cases for // decimals by disallowing values having scale (eg. `1.1`). if (bigDecimal.scale <= 0) { @@ -138,21 +138,21 @@ object CSVInferSchema { // 2. scale is bigger than precision. DecimalType(bigDecimal.precision, bigDecimal.scale) } else { - tryParseDouble(field, options) + tryParseDouble(field) } } - decimalTry.getOrElse(tryParseDouble(field, options)) + decimalTry.getOrElse(tryParseDouble(field)) } - private def tryParseDouble(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) { DoubleType } else { - tryParseTimestamp(field, options) + tryParseTimestamp(field) } } - private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { + private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. 
if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType @@ -160,11 +160,11 @@ object CSVInferSchema { // We keep this for backwards compatibility. TimestampType } else { - tryParseBoolean(field, options) + tryParseBoolean(field) } } - private def tryParseBoolean(field: String, options: CSVOptions): DataType = { + private def tryParseBoolean(field: String): DataType = { if ((allCatch opt field.toBoolean).isDefined) { BooleanType } else { http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index ed19693..85e1292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream -import java.math.BigDecimal import scala.util.Try import scala.util.control.NonFatal @@ -27,7 +26,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -104,6 +103,8 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + /** * Create a converter which converts the string value to a value according to a desired type. * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). 
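The body of `ExprUtils.getDecimalParser(options.locale)` referenced just above is not included in this diff. As a rough, hypothetical sketch of what such a locale-aware parser can look like on top of `java.text.DecimalFormat` (the class this PR builds on); the real helper may differ in details:

    import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
    import java.util.Locale

    // Hypothetical sketch only; the actual implementation lives in ExprUtils.
    def getDecimalParser(locale: Locale): String => java.math.BigDecimal = {
      val df = new DecimalFormat("", new DecimalFormatSymbols(locale))
      df.setParseBigDecimal(true)                 // make parse() return java.math.BigDecimal
      (s: String) => {
        val pos = new ParsePosition(0)
        val result = df.parse(s, pos).asInstanceOf[java.math.BigDecimal]
        if (result == null || pos.getIndex != s.length) {
          // Fail loudly so callers such as tryParseDecimal can fall back to other types.
          throw new IllegalArgumentException(s"Cannot parse '$s' as a decimal in locale $locale")
        }
        result
      }
    }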
@@ -149,8 +150,7 @@ class UnivocityParser( case dt: DecimalType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) + Decimal(decimalParser(datum), dt.precision, dt.scale) } case _: TimestampType => (d: String) => http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 1e4e1c6..83b0299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -180,8 +180,9 @@ case class SchemaOfCsv( val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) - val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + val inferSchema = new CSVInferSchema(parsedOptions) + val fieldTypes = inferSchema.inferRowType(startType, row) + val st = StructType(inferSchema.toStructFields(fieldTypes, header)) UTF8String.fromString(st.catalogString) } http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 651846d..1a020e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -17,126 +17,175 @@ package org.apache.spark.sql.catalyst.csv +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class CSVInferSchemaSuite extends SparkFunSuite { +class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { test("String fields types are inferred correctly from null types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(NullType, "", options) == NullType) - assert(CSVInferSchema.inferField(NullType, null, options) == NullType) - assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) - assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType) - assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "test", options) == StringType) - assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + 
assert(inferSchema.inferField(NullType, "") == NullType) + assert(inferSchema.inferField(NullType, null) == NullType) + assert(inferSchema.inferField(NullType, "100000000000") == LongType) + assert(inferSchema.inferField(NullType, "60") == IntegerType) + assert(inferSchema.inferField(NullType, "3.5") == DoubleType) + assert(inferSchema.inferField(NullType, "test") == StringType) + assert(inferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(NullType, "True") == BooleanType) + assert(inferSchema.inferField(NullType, "FAlSE") == BooleanType) val textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(NullType, textValueOne) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) - assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) - assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "1.0") == DoubleType) + assert(inferSchema.inferField(LongType, "test") == StringType) + assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(inferSchema.inferField(DoubleType, null) == DoubleType) + assert(inferSchema.inferField(DoubleType, "test") == StringType) + assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(LongType, "True") == BooleanType) + assert(inferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(inferSchema.inferField(TimestampType, "FALSE") == BooleanType) val textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(IntegerType, textValueOne) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015", options) 
== TimestampType) + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(TimestampType, "2015") == TimestampType) } test("Timestamp field types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(inferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) } test("Boolean fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "Fale") == StringType) + assert(inferSchema.inferField(DoubleType, "TRUEe") == StringType) } test("Type arrays are merged to highest common type") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + assert( - CSVInferSchema.mergeRowTypes(Array(StringType), + inferSchema.mergeRowTypes(Array(StringType), Array(DoubleType)).deep == Array(StringType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(IntegerType), + inferSchema.mergeRowTypes(Array(IntegerType), Array(LongType)).deep == Array(LongType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(DoubleType), + inferSchema.mergeRowTypes(Array(DoubleType), Array(LongType)).deep == Array(DoubleType).deep) } test("Null fields are handled properly when a nullValue is specified") { var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) - assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "null") == NullType) + assert(inferSchema.inferField(StringType, "null") == StringType) + assert(inferSchema.inferField(LongType, "null") == LongType) options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) - assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) - assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) - assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) - assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) + inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "\\N") == IntegerType) + assert(inferSchema.inferField(DoubleType, "\\N") == DoubleType) + assert(inferSchema.inferField(TimestampType, "\\N") == TimestampType) + assert(inferSchema.inferField(BooleanType, "\\N") == BooleanType) + assert(inferSchema.inferField(DecimalType(1, 1), "\\N") == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { - val mergedNullTypes = 
CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + + val mergedNullTypes = inferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). - assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == + assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") == DecimalType(4, -9)) // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. val value = "12345678901234567890.01234567890123456789" - assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) + assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType) // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType - assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) - assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) + assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == DecimalType(20, 0)) + assert(inferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00") == StringType) } test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "nan") == DoubleType) + assert(inferSchema.inferField(NullType, "inf") == DoubleType) + assert(inferSchema.inferField(NullType, "-inf") == DoubleType) + } + + test("inferring the decimal type using locale") { + def checkDecimalInfer(langTag: String, expectedType: DataType): Unit = { + val options = new CSVOptions( + parameters = Map("locale" -> langTag, "inferSchema" -> "true", "sep" -> "|"), + columnPruning = false, + defaultTimeZoneId = "GMT") + val inferSchema = new CSVInferSchema(options) + + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(Decimal(1000001).toBigDecimal) + + assert(inferSchema.inferField(NullType, input) == expectedType) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index e4e7dc2..7212402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class UnivocityParserSuite extends SparkFunSuite { +class UnivocityParserSuite extends SparkFunSuite with SQLHelper { private val parser = new UnivocityParser( StructType(Seq.empty), new CSVOptions(Map.empty[String, String], false, "GMT")) @@ -196,4 +200,20 @@ class UnivocityParserSuite extends SparkFunSuite { assert(doubleVal2 == Double.PositiveInfinity) } + test("parse decimals using locale") { + def checkDecimalParsing(langTag: String): Unit = { + val decimalVal = new BigDecimal("1000.001") + val decimalType = new DecimalType(10, 5) + val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale) + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(expected.toBigDecimal) + + val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") + val parser = new UnivocityParser(new StructType().add("d", decimalType), options) + + assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === expected) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b35b885..b46dfb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -135,7 +135,7 @@ object TextInputCSVDataSource extends CSVDataSource { val parser = new CsvParser(parsedOptions.asParserSettings) linesWithoutHeader.map(parser.parseLine) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(tokenRDD, header) case _ => // If the first line could not be read, just return the empty schema. StructType(Nil) @@ -208,7 +208,7 @@ object MultiLineCSVDataSource extends CSVDataSource { encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) - CSVInferSchema.infer(sampled, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(sampled, header) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
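As an illustration of the behaviour change in `UnivocityParser` above: the old converter stripped `,` as if it were always a grouping separator, while the new locale-aware parser interprets it according to the configured locale. The comparison below is illustrative only and not part of the patch; the values in the comments follow from the de-DE number format.

    import java.text.{DecimalFormat, DecimalFormatSymbols}
    import java.util.Locale

    val datum = "1000,001"                                  // German-formatted 1000.001
    // Old behaviour: strip commas, then parse as a plain BigDecimal -> 1000001 (1000x too large)
    val oldValue = new java.math.BigDecimal(datum.replaceAll(",", ""))
    // New behaviour: a DecimalFormat for Locale.GERMANY treats ',' as the decimal separator -> 1000.001
    val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.GERMANY))
    df.setParseBigDecimal(true)
    val newValue = df.parse(datum).asInstanceOf[java.math.BigDecimal]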