Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/16351#discussion_r93381049 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala --- @@ -215,84 +215,133 @@ private[csv] object CSVInferSchema { } private[csv] object CSVTypeCast { + // A `ValueConverter` is responsible for converting the given value to a desired type. + private type ValueConverter = String => Any /** - * Casts given string datum to specified type. - * Currently we do not support complex types (ArrayType, MapType, StructType). + * Create converters which cast each given string datum to each specified type in given schema. + * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). * - * For string types, this is simply the datum. For other types. + * For string types, this is simply the datum. + * For other types, this is converted into the value according to the type. * For other nullable types, returns null if it is null or equals to the value specified * in `nullValue` option. * - * @param datum string value - * @param name field name in schema. - * @param castType data type to cast `datum` into. - * @param nullable nullability for the field. + * @param schema schema that contains data types to cast the given value into. * @param options CSV options. */ - def castTo( + def makeConverters( + schema: StructType, + options: CSVOptions = CSVOptions()): Array[ValueConverter] = { + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + } + + /** + * Create a converter which converts the string value to a value according to a desired type. 
+ */ + def makeConverter( + name: String, + dataType: DataType, + nullable: Boolean = true, + options: CSVOptions = CSVOptions()): ValueConverter = dataType match { + case _: ByteType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + datum.toByte + } + + case _: ShortType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + datum.toShort + } + + case _: IntegerType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + datum.toInt + } + + case _: LongType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + datum.toLong + } + + case _: FloatType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case datum => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).floatValue()) + } + + case _: DoubleType => (d: String) => + nullSafeDatum(d, name, nullable, options) { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case datum => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.US).parse(datum).doubleValue()) + } + + case _: BooleanType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + datum.toBoolean + } + + case dt: DecimalType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + } + + case _: TimestampType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. 
+ Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + } + + case _: DateType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + } + + case _: StringType => (d: String) => + nullSafeDatum(d, name, nullable, options) { case datum => + UTF8String.fromString(datum) + } + + case udt: UserDefinedType[_] => (datum: String) => + makeConverter(name, udt.sqlType, nullable, options) + + case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") + } + + private def nullSafeDatum( datum: String, name: String, - castType: DataType, - nullable: Boolean = true, - options: CSVOptions = CSVOptions()): Any = { - - // datum can be null if the number of fields found is less than the length of the schema + nullable: Boolean, + options: CSVOptions)(f: PartialFunction[String, Any]): Any = { --- End diff -- why require a `PartialFunction` here?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org