Repository: spark
Updated Branches:
  refs/heads/master c5ef477d2 -> c9667aff4


[SPARK-25672][SQL] schema_of_csv() - schema inference from an example

## What changes were proposed in this pull request?

In this PR, I propose to add a new function, *schema_of_csv()*, which infers 
the schema of a CSV string literal. The function returns a string containing 
the schema in DDL format. For example:

```sql
select schema_of_csv('1|abc', map('delimiter', '|'))
```
```
struct<_c0:int,_c1:string>
```
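
The same inference is available through the Scala DataFrame API. A minimal 
sketch (assuming a running `SparkSession` named `spark`; the results follow 
the tests added in this patch):

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.{from_csv, lit, schema_of_csv}
import spark.implicits._

// Infer the schema of a sample record, using a custom delimiter.
val ddl = spark.range(1)
  .select(schema_of_csv(lit("1|abc"), Map("delimiter" -> "|").asJava))
  .head().getString(0)
// ddl: String = struct<_c0:int,_c1:string>

// The inferred DDL string can be fed straight back into from_csv.
val parsed = Seq("1|abc").toDF("value")
  .select(from_csv($"value", lit(ddl), Map("delimiter" -> "|").asJava))
```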

## How was this patch tested?

Added new tests to `CsvFunctionsSuite` and `CsvExpressionsSuite`, and SQL tests to `csv-functions.sql`.

Closes #22666 from MaxGekk/schema_of_csv-function.

Lead-authored-by: hyukjinkwon <gurwls...@apache.org>
Co-authored-by: Maxim Gekk <maxim.g...@databricks.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c9667aff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c9667aff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c9667aff

Branch: refs/heads/master
Commit: c9667aff4f4888b650fad2ed41698025b1e84166
Parents: c5ef477
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Thu Nov 1 09:14:16 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Thu Nov 1 09:14:16 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  41 +++-
 .../catalyst/analysis/FunctionRegistry.scala    |   3 +-
 .../spark/sql/catalyst/csv/CSVInferSchema.scala | 220 +++++++++++++++++++
 .../sql/catalyst/expressions/ExprUtils.scala    |  33 ++-
 .../catalyst/expressions/csvExpressions.scala   |  54 +++++
 .../catalyst/expressions/jsonExpressions.scala  |  16 +-
 .../sql/catalyst/csv/CSVInferSchemaSuite.scala  | 142 ++++++++++++
 .../sql/catalyst/csv/UnivocityParserSuite.scala | 199 +++++++++++++++++
 .../expressions/CsvExpressionsSuite.scala       |  10 +
 .../datasources/csv/CSVDataSource.scala         |   2 +-
 .../datasources/csv/CSVInferSchema.scala        | 214 ------------------
 .../scala/org/apache/spark/sql/functions.scala  |  35 +++
 .../sql-tests/inputs/csv-functions.sql          |   8 +
 .../sql-tests/results/csv-functions.sql.out     |  54 ++++-
 .../apache/spark/sql/CsvFunctionsSuite.scala    |  15 ++
 .../datasources/csv/CSVInferSchemaSuite.scala   | 143 ------------
 .../datasources/csv/UnivocityParserSuite.scala  | 200 -----------------
 17 files changed, 803 insertions(+), 586 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ca2a256..beb1a06 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(3.0)
+def schema_of_csv(csv, options={}):
+    """
+    Parses a CSV string and infers its schema in DDL format.
+
+    :param csv: a CSV string or a string literal containing a CSV string.
+    :param options: options to control parsing. Accepts the same options as the CSV datasource
+
+    >>> df = spark.range(1)
+    >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
+    [Row(csv=u'struct<_c0:int,_c1:string>')]
+    >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
+    [Row(csv=u'struct<_c0:int,_c1:string>')]
+    """
+    if isinstance(csv, basestring):
+        col = _create_column_from_literal(csv)
+    elif isinstance(csv, Column):
+        col = _to_java_column(csv)
+    else:
+        raise TypeError("csv argument should be a column or string")
+
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.schema_of_csv(col, options)
+    return Column(jc)
+
+
 @since(1.5)
 def size(col):
     """
@@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}):
     :param schema: a string with schema in DDL format to use when parsing the CSV column.
     :param options: options to control parsing. accepts the same options as the CSV datasource
 
-    >>> data = [(1, '1')]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
-    [Row(csv=Row(a=1))]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
-    [Row(csv=Row(a=1))]
+    >>> data = [("1,2,3",)]
+    >>> df = spark.createDataFrame(data, ("value",))
+    >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
+    [Row(csv=Row(a=1, b=2, c=3))]
+    >>> value = data[0][0]
+    >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
+    [Row(csv=Row(_c0=1, _c1=2, _c2=3))]
     """
 
     sc = SparkContext._active_spark_context

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index af6166b..cf8fb7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -526,7 +526,8 @@ object FunctionRegistry {
     castAlias("string", StringType),
 
     // csv
-    expression[CsvToStructs]("from_csv")
+    expression[CsvToStructs]("from_csv"),
+    expression[SchemaOfCsv]("schema_of_csv")
   )
 
   val builtin: SimpleFunctionRegistry = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
new file mode 100644
index 0000000..799e999
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.csv
+
+import java.math.BigDecimal
+
+import scala.util.control.Exception.allCatch
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+object CSVInferSchema {
+
+  /**
+   * Similar to the JSON schema inference
+   *     1. Infer type of each row
+   *     2. Merge row types to find common type
+   *     3. Replace any null types with string type
+   */
+  def infer(
+      tokenRDD: RDD[Array[String]],
+      header: Array[String],
+      options: CSVOptions): StructType = {
+    val fields = if (options.inferSchemaFlag) {
+      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+      val rootTypes: Array[DataType] =
+        tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
+
+      toStructFields(rootTypes, header, options)
+    } else {
+      // By default fields are assumed to be StringType
+      header.map(fieldName => StructField(fieldName, StringType, nullable = true))
+    }
+
+    StructType(fields)
+  }
+
+  def toStructFields(
+      fieldTypes: Array[DataType],
+      header: Array[String],
+      options: CSVOptions): Array[StructField] = {
+    header.zip(fieldTypes).map { case (thisHeader, rootType) =>
+      val dType = rootType match {
+        case _: NullType => StringType
+        case other => other
+      }
+      StructField(thisHeader, dType, nullable = true)
+    }
+  }
+
+  def inferRowType(options: CSVOptions)
+      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+    var i = 0
+    while (i < math.min(rowSoFar.length, next.length)) {  // Rows may be missing trailing columns.
+      rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
+      i += 1
+    }
+    rowSoFar
+  }
+
+  def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
+    first.zipAll(second, NullType, NullType).map { case (a, b) =>
+      compatibleType(a, b).getOrElse(NullType)
+    }
+  }
+
+  /**
+   * Infers the type of a string field. Given a known type of Double and a string "1",
+   * there is no point checking whether it is an Int, as the final type must be
+   * Double or higher.
+   */
+  def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
+    if (field == null || field.isEmpty || field == options.nullValue) {
+      typeSoFar
+    } else {
+      typeSoFar match {
+        case NullType => tryParseInteger(field, options)
+        case IntegerType => tryParseInteger(field, options)
+        case LongType => tryParseLong(field, options)
+        case _: DecimalType =>
+          // DecimalTypes have different precisions and scales, so we try to find the common type.
+          compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
+        case DoubleType => tryParseDouble(field, options)
+        case TimestampType => tryParseTimestamp(field, options)
+        case BooleanType => tryParseBoolean(field, options)
+        case StringType => StringType
+        case other: DataType =>
+          throw new UnsupportedOperationException(s"Unexpected data type $other")
+      }
+    }
+  }
+
+  private def isInfOrNan(field: String, options: CSVOptions): Boolean = {
+    field == options.nanValue || field == options.negativeInf || field == options.positiveInf
+  }
+
+  private def tryParseInteger(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toInt).isDefined) {
+      IntegerType
+    } else {
+      tryParseLong(field, options)
+    }
+  }
+
+  private def tryParseLong(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toLong).isDefined) {
+      LongType
+    } else {
+      tryParseDecimal(field, options)
+    }
+  }
+
+  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+    val decimalTry = allCatch opt {
+      // `BigDecimal` conversion can fail when the `field` is not in a number format.
+      val bigDecimal = new BigDecimal(field)
+      // Because many other formats do not support decimal, this reduces the cases for
+      // decimals by disallowing values with a positive scale (e.g. `1.1`).
+      if (bigDecimal.scale <= 0) {
+        // `DecimalType` conversion can fail when
+        //   1. The precision is bigger than 38.
+        //   2. scale is bigger than precision.
+        DecimalType(bigDecimal.precision, bigDecimal.scale)
+      } else {
+        tryParseDouble(field, options)
+      }
+    }
+    decimalTry.getOrElse(tryParseDouble(field, options))
+  }
+
+  private def tryParseDouble(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
+      DoubleType
+    } else {
+      tryParseTimestamp(field, options)
+    }
+  }
+
+  private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
+    // This case applies when a custom `timestampFormat` is set.
+    if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
+      TimestampType
+    } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
+      // We keep this for backwards compatibility.
+      TimestampType
+    } else {
+      tryParseBoolean(field, options)
+    }
+  }
+
+  private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toBoolean).isDefined) {
+      BooleanType
+    } else {
+      stringType()
+    }
+  }
+
+  // Defining a function to return the StringType constant is necessary in order to work around
+  // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
+  // see issue #128 for more details.
+  private def stringType(): DataType = {
+    StringType
+  }
+
+  /**
+   * Returns the common data type given two input data types so that the return type
+   * is compatible with both input data types.
+   */
+  private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = {
+    TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2))
+  }
+
+  /**
+   * The following pattern matching represents additional type promotion rules that
+   * are CSV specific.
+   */
+  private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = {
+    case (StringType, t2) => Some(StringType)
+    case (t1, StringType) => Some(StringType)
+
+    // These two cases below deal with when `IntegralType` is larger than `DecimalType`.
+    case (t1: IntegralType, t2: DecimalType) =>
+      compatibleType(DecimalType.forType(t1), t2)
+    case (t1: DecimalType, t2: IntegralType) =>
+      compatibleType(t1, DecimalType.forType(t2))
+
+    // Double supports a larger range than fixed decimal; DecimalType.Maximum should be enough
+    // in most cases, and also has better precision.
+    case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+      Some(DoubleType)
+
+    case (t1: DecimalType, t2: DecimalType) =>
+      val scale = math.max(t1.scale, t2.scale)
+      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+      if (range + scale > 38) {
+        // DecimalType can't support precision > 38
+        Some(DoubleType)
+      } else {
+        Some(DecimalType(range + scale, scale))
+      }
+    case _ => None
+  }
+}
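
To make the inference and widening rules above concrete, here is a small REPL-style sketch against the now-public helpers. It mirrors assertions from the new `CSVInferSchemaSuite` (the `CSVOptions` constructor arguments are the same ones the suite uses), so treat it as illustrative rather than normative:

```scala
import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types._

val options = new CSVOptions(Map.empty[String, String], false, "GMT")

// Each field walks the chain Null -> Int -> Long -> Decimal -> Double ->
// Timestamp -> Boolean -> String, never narrowing a previously seen type.
CSVInferSchema.inferField(NullType, "60", options)      // IntegerType
CSVInferSchema.inferField(IntegerType, "1.0", options)  // DoubleType
CSVInferSchema.inferField(DoubleType, "test", options)  // StringType

// Column-wise merging across rows picks the tightest common type.
CSVInferSchema.mergeRowTypes(Array(IntegerType), Array(LongType))  // Array(LongType)

// Decimal widening: merging Decimal(3, -10) with "1.19E+11" (Decimal(3, -9))
// gives scale = max(-10, -9) = -9 and range = max(13, 12) = 13, so
// range + scale = 4 <= 38 and the result is DecimalType(4, -9).
CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options)
```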

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index e570889..040b56c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.types.{MapType, StringType, StructType}
+import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 object ExprUtils {
 
-  def evalSchemaExpr(exp: Expression): StructType = exp match {
-    case Literal(s, StringType) => StructType.fromDDL(s.toString)
+  def evalSchemaExpr(exp: Expression): StructType = {
+    // Use `DataType.fromDDL` since the type string can be struct<...>.
+    val dataType = exp match {
+      case Literal(s, StringType) =>
+        DataType.fromDDL(s.toString)
+      case e @ SchemaOfCsv(_: Literal, _) =>
+        val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
+        DataType.fromDDL(ddlSchema.toString)
+      case e => throw new AnalysisException(
+        "Schema should be specified in DDL format as a string literal or output of " +
+          s"the schema_of_csv function instead of ${e.sql}")
+    }
+
+    if (!dataType.isInstanceOf[StructType]) {
+      throw new AnalysisException(
+        s"Schema should be struct type but got ${dataType.sql}.")
+    }
+    dataType.asInstanceOf[StructType]
+  }
+
+  def evalTypeExpr(exp: Expression): DataType = exp match {
+    case Literal(s, StringType) => DataType.fromDDL(s.toString)
+    case e @ SchemaOfJson(_: Literal, _) =>
+      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
+      DataType.fromDDL(ddlSchema.toString)
     case e => throw new AnalysisException(
-      s"Schema should be specified in DDL format as a string literal instead 
of ${e.sql}")
+      "Schema should be specified in DDL format as a string literal or output 
of " +
+        s"the schema_of_json function instead of ${e.sql}")
   }
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
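
In other words, `from_csv` now accepts its schema either as a plain DDL string literal or as a foldable `schema_of_csv` call. A rough sketch of what the helper accepts, written as a hypothetical session against these catalyst internals:

```scala
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, Literal, SchemaOfCsv}

// A DDL string literal is parsed with DataType.fromDDL ...
ExprUtils.evalSchemaExpr(Literal("a INT, b STRING"))
// ... and schema_of_csv over a literal is folded to its DDL output first.
ExprUtils.evalSchemaExpr(SchemaOfCsv(Literal("1,abc"), Map.empty))
// Anything else (e.g. a column reference) raises an AnalysisException.
```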

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 853b1ea..e70296f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util._
@@ -120,3 +123,54 @@ case class CsvToStructs(
 
   override def prettyName: String = "from_csv"
 }
+
+/**
+ * A function that infers the schema of a CSV string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(csv[, options]) - Returns the schema of a CSV string in DDL format.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('1,abc');
+       struct<_c0:int,_c1:string>
+  """,
+  since = "3.0.0")
+case class SchemaOfCsv(
+    child: Expression,
+    options: Map[String, String])
+  extends UnaryExpression with CodegenFallback {
+
+  def this(child: Expression) = this(child, Map.empty[String, String])
+
+  def this(child: Expression, options: Expression) = this(
+    child = child,
+    options = ExprUtils.convertToMapData(options))
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = false
+
+  @transient
+  private lazy val csv = child.eval().asInstanceOf[UTF8String]
+
+  override def checkInputDataTypes(): TypeCheckResult = child match {
+    case Literal(s, StringType) if s != null => super.checkInputDataTypes()
+    case _ => TypeCheckResult.TypeCheckFailure(
+      s"The input csv should be a string literal and not null; however, got ${child.sql}.")
+  }
+
+  override def eval(v: InternalRow): Any = {
+    val parsedOptions = new CSVOptions(options, true, "UTC")
+    val parser = new CsvParser(parsedOptions.asParserSettings)
+    val row = parser.parseLine(csv.toString)
+    assert(row != null, "Parsed CSV record should not be null.")
+
+    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+    val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
+    val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+    UTF8String.fromString(st.catalogString)
+  }
+
+  override def prettyName: String = "schema_of_csv"
+}
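
Because `checkInputDataTypes` restricts the child to a non-null string literal, `eval` can parse the single record eagerly. A line-by-line sketch of that path, using the same helpers and the same assumptions the expression makes (one header-less record, UTC time zone):

```scala
import com.univocity.parsers.csv.CsvParser
import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types._

val parsedOptions = new CSVOptions(Map.empty[String, String], true, "UTC")
val row = new CsvParser(parsedOptions.asParserSettings).parseLine("1,abc")

// Synthesize positional column names, then run one round of row inference.
val header = row.indices.map(i => s"_c$i").toArray
val startType = Array.fill[DataType](header.length)(NullType)
val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
st.catalogString  // "struct<_c0:int,_c1:string>"
```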

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 77af590..eafcb61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -529,7 +529,7 @@ case class JsonToStructs(
   // Used in `FunctionRegistry`
   def this(child: Expression, schema: Expression, options: Map[String, String]) =
     this(
-      schema = JsonExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = options,
       child = child,
       timeZoneId = None)
@@ -538,7 +538,7 @@ case class JsonToStructs(
 
   def this(child: Expression, schema: Expression, options: Expression) =
     this(
-      schema = JsonExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = ExprUtils.convertToMapData(options),
       child = child,
       timeZoneId = None)
@@ -784,15 +784,3 @@ case class SchemaOfJson(
 
   override def prettyName: String = "schema_of_json"
 }
-
-object JsonExprUtils {
-  def evalSchemaExpr(exp: Expression): DataType = exp match {
-    case Literal(s, StringType) => DataType.fromDDL(s.toString)
-    case e @ SchemaOfJson(_: Literal, _) =>
-      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
-      DataType.fromDDL(ddlSchema.toString)
-    case e => throw new AnalysisException(
-      "Schema should be specified in DDL format as a string literal" +
-      s" or output of the schema_of_json function instead of ${e.sql}")
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
new file mode 100644
index 0000000..651846d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.csv
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class CSVInferSchemaSuite extends SparkFunSuite {
+
+  test("String fields types are inferred correctly from null types") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
+    assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
+    assert(CSVInferSchema.inferField(NullType, "100000000000", options) == 
LongType)
+    assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
+    assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
+    assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
+    assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) 
== TimestampType)
+    assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
+    assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == 
BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(NullType, textValueOne, options) == 
expectedTypeOne)
+  }
+
+  test("String fields types are inferred correctly from other types") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
+    assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
+    assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == 
DoubleType)
+    assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
+    assert(CSVInferSchema.inferField(DoubleType, "test", options) == 
StringType)
+    assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) 
== TimestampType)
+    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", 
options) == TimestampType)
+    assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
+    assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == 
BooleanType)
+    assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == 
BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == 
expectedTypeOne)
+  }
+
+  test("Timestamp field types are inferred correctly via custom data format") {
+    var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, 
"GMT")
+    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
+    options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT")
+    assert(CSVInferSchema.inferField(TimestampType, "2015", options) == 
TimestampType)
+  }
+
+  test("Timestamp field types are inferred correctly from other types") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == 
StringType)
+    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) 
== StringType)
+    assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == 
StringType)
+  }
+
+  test("Boolean fields types are inferred correctly from other types") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
+    assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == 
StringType)
+  }
+
+  test("Type arrays are merged to highest common type") {
+    assert(
+      CSVInferSchema.mergeRowTypes(Array(StringType),
+        Array(DoubleType)).deep == Array(StringType).deep)
+    assert(
+      CSVInferSchema.mergeRowTypes(Array(IntegerType),
+        Array(LongType)).deep == Array(LongType).deep)
+    assert(
+      CSVInferSchema.mergeRowTypes(Array(DoubleType),
+        Array(LongType)).deep == Array(DoubleType).deep)
+  }
+
+  test("Null fields are handled properly when a nullValue is specified") {
+    var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT")
+    assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
+    assert(CSVInferSchema.inferField(StringType, "null", options) == 
StringType)
+    assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
+
+    options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT")
+    assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == 
IntegerType)
+    assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
+    assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == 
TimestampType)
+    assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == 
BooleanType)
+    assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == 
DecimalType(1, 1))
+  }
+
+  test("Merging Nulltypes should yield Nulltype.") {
+    val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), 
Array(NullType))
+    assert(mergedNullTypes.deep == Array(NullType).deep)
+  }
+
+  test("SPARK-18433: Improve DataSource option keys to be more 
case-insensitive") {
+    val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, 
"GMT")
+    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
+  }
+
+  test("SPARK-18877: `inferField` on DecimalType should find a common type 
with `typeSoFar`") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+
+    // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
+    assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) 
==
+      DecimalType(4, -9))
+
+    // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 
and scale 20.
+    val value = "12345678901234567890.01234567890123456789"
+    assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == 
DoubleType)
+
+    // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
+    assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) 
== DecimalType(20, 0))
+    assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 
00:00:00", options)
+      == StringType)
+  }
+
+  test("DoubleType should be inferred when user defined nan/inf are provided") 
{
+    val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> 
"-inf",
+      "positiveInf" -> "inf"), false, "GMT")
+    assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType)
+    assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType)
+    assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
new file mode 100644
index 0000000..e4e7dc2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.csv
+
+import java.math.BigDecimal
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class UnivocityParserSuite extends SparkFunSuite {
+  private val parser = new UnivocityParser(
+    StructType(Seq.empty),
+    new CSVOptions(Map.empty[String, String], false, "GMT"))
+
+  private def assertNull(v: Any) = assert(v == null)
+
+  test("Can parse decimal type values") {
+    val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
+    val decimalValues = Seq(10.05, 1000.01, 158058049.001)
+    val decimalType = new DecimalType()
+
+    stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
+      val decimalValue = new BigDecimal(decimalVal.toString)
+      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+      assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) ===
+        Decimal(decimalValue, decimalType.precision, decimalType.scale))
+    }
+  }
+
+  test("Nullable types are handled") {
+    val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, 
DoubleType,
+      BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, 
StringType)
+
+    // Nullable field with nullValue option.
+    types.foreach { t =>
+      // Tests that a custom nullValue.
+      val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, 
"GMT")
+      val converter =
+        parser.makeConverter("_1", t, nullable = true, options = 
nullValueOptions)
+      assertNull(converter.apply("-"))
+      assertNull(converter.apply(null))
+
+      // Tests that the default nullValue is empty string.
+      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+      assertNull(parser.makeConverter("_1", t, nullable = true, options = 
options).apply(""))
+    }
+
+    // Not nullable field with nullValue option.
+    types.foreach { t =>
+      // Casts a null to not nullable field should throw an exception.
+      val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT")
+      val converter =
+        parser.makeConverter("_1", t, nullable = false, options = options)
+      var message = intercept[RuntimeException] {
+        converter.apply("-")
+      }.getMessage
+      assert(message.contains("null value found but field _1 is not 
nullable."))
+      message = intercept[RuntimeException] {
+        converter.apply(null)
+      }.getMessage
+      assert(message.contains("null value found but field _1 is not 
nullable."))
+    }
+
+    // If nullValue is different with empty string, then, empty string should 
not be casted into
+    // null.
+    Seq(true, false).foreach { b =>
+      val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT")
+      val converter =
+        parser.makeConverter("_1", StringType, nullable = b, options = options)
+      assert(converter.apply("") == UTF8String.fromString(""))
+    }
+  }
+
+  test("Throws exception for empty string with non null type") {
+      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    val exception = intercept[RuntimeException]{
+      parser.makeConverter("_1", IntegerType, nullable = false, options = 
options).apply("")
+    }
+    assert(exception.getMessage.contains("null value found but field _1 is not 
nullable."))
+  }
+
+  test("Types are cast correctly") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    assert(parser.makeConverter("_1", ByteType, options = options).apply("10") 
== 10)
+    assert(parser.makeConverter("_1", ShortType, options = 
options).apply("10") == 10)
+    assert(parser.makeConverter("_1", IntegerType, options = 
options).apply("10") == 10)
+    assert(parser.makeConverter("_1", LongType, options = options).apply("10") 
== 10)
+    assert(parser.makeConverter("_1", FloatType, options = 
options).apply("1.00") == 1.0)
+    assert(parser.makeConverter("_1", DoubleType, options = 
options).apply("1.00") == 1.0)
+    assert(parser.makeConverter("_1", BooleanType, options = 
options).apply("true") == true)
+
+    val timestampsOptions =
+      new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, 
"GMT")
+    val customTimestamp = "31/01/2015 00:00"
+    val expectedTime = 
timestampsOptions.timestampFormat.parse(customTimestamp).getTime
+    val castedTimestamp =
+      parser.makeConverter("_1", TimestampType, nullable = true, options = 
timestampsOptions)
+        .apply(customTimestamp)
+    assert(castedTimestamp == expectedTime * 1000L)
+
+    val customDate = "31/01/2015"
+    val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, 
"GMT")
+    val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
+    val castedDate =
+      parser.makeConverter("_1", DateType, nullable = true, options = 
dateOptions)
+        .apply(customTimestamp)
+    assert(castedDate == DateTimeUtils.millisToDays(expectedDate))
+
+    val timestamp = "2015-01-01 00:00:00"
+    assert(parser.makeConverter("_1", TimestampType, options = 
options).apply(timestamp) ==
+      DateTimeUtils.stringToTime(timestamp).getTime  * 1000L)
+    assert(parser.makeConverter("_1", DateType, options = 
options).apply("2015-01-01") ==
+      
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
+  }
+
+  test("Throws exception for casting an invalid string to Float and Double 
Types") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    val types = Seq(DoubleType, FloatType)
+    val input = Seq("10u000", "abc", "1 2/3")
+    types.foreach { dt =>
+      input.foreach { v =>
+        val message = intercept[NumberFormatException] {
+          parser.makeConverter("_1", dt, options = options).apply(v)
+        }.getMessage
+        assert(message.contains(v))
+      }
+    }
+  }
+
+  test("Float NaN values are parsed correctly") {
+    val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT")
+    val floatVal: Float = parser.makeConverter(
+      "_1", FloatType, nullable = true, options = options
+    ).apply("nn").asInstanceOf[Float]
+
+    // Java implements the IEEE-754 floating point standard which guarantees that any
+    // comparison against NaN will return false (except != which returns true)
+    assert(floatVal != floatVal)
+  }
+
+  test("Double NaN values are parsed correctly") {
+    val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT")
+    val doubleVal: Double = parser.makeConverter(
+      "_1", DoubleType, nullable = true, options = options
+    ).apply("-").asInstanceOf[Double]
+
+    assert(doubleVal.isNaN)
+  }
+
+  test("Float infinite values can be parsed") {
+    val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT")
+    val floatVal1 = parser.makeConverter(
+      "_1", FloatType, nullable = true, options = negativeInfOptions
+    ).apply("max").asInstanceOf[Float]
+
+    assert(floatVal1 == Float.NegativeInfinity)
+
+    val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT")
+    val floatVal2 = parser.makeConverter(
+      "_1", FloatType, nullable = true, options = positiveInfOptions
+    ).apply("max").asInstanceOf[Float]
+
+    assert(floatVal2 == Float.PositiveInfinity)
+  }
+
+  test("Double infinite values can be parsed") {
+    val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT")
+    val doubleVal1 = parser.makeConverter(
+      "_1", DoubleType, nullable = true, options = negativeInfOptions
+    ).apply("max").asInstanceOf[Double]
+
+    assert(doubleVal1 == Double.NegativeInfinity)
+
+    val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT")
+    val doubleVal2 = parser.makeConverter(
+      "_1", DoubleType, nullable = true, options = positiveInfOptions
+    ).apply("max").asInstanceOf[Double]
+
+    assert(doubleVal2 == Double.PositiveInfinity)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
index 65987af..386e0d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
     }.getCause
     assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode"))
   }
+
+  test("infer schema of CSV strings") {
+    checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>")
+  }
+
+  test("infer schema of CSV strings by using options") {
+    checkEvaluation(
+      new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")),
+      "struct<_c0:int,_c1:string>")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 9e7b45d..4808e8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
deleted file mode 100644
index 4326a18..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import java.math.BigDecimal
-
-import scala.util.control.Exception._
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.csv.CSVOptions
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types._
-
-private[csv] object CSVInferSchema {
-
-  /**
-   * Similar to the JSON schema inference
-   *     1. Infer type of each row
-   *     2. Merge row types to find common type
-   *     3. Replace any null types with string type
-   */
-  def infer(
-      tokenRDD: RDD[Array[String]],
-      header: Array[String],
-      options: CSVOptions): StructType = {
-    val fields = if (options.inferSchemaFlag) {
-      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-      val rootTypes: Array[DataType] =
-        tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
-
-      header.zip(rootTypes).map { case (thisHeader, rootType) =>
-        val dType = rootType match {
-          case _: NullType => StringType
-          case other => other
-        }
-        StructField(thisHeader, dType, nullable = true)
-      }
-    } else {
-      // By default fields are assumed to be StringType
-      header.map(fieldName => StructField(fieldName, StringType, nullable = true))
-    }
-
-    StructType(fields)
-  }
-
-  private def inferRowType(options: CSVOptions)
-      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
-    var i = 0
-    while (i < math.min(rowSoFar.length, next.length)) {  // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
-      i+=1
-    }
-    rowSoFar
-  }
-
-  def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
-    first.zipAll(second, NullType, NullType).map { case (a, b) =>
-      compatibleType(a, b).getOrElse(NullType)
-    }
-  }
-
-  /**
-   * Infer type of string field. Given known type Double, and a string "1", there is no
-   * point checking if it is an Int, as the final type must be Double or higher.
-   */
-  def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
-    if (field == null || field.isEmpty || field == options.nullValue) {
-      typeSoFar
-    } else {
-      typeSoFar match {
-        case NullType => tryParseInteger(field, options)
-        case IntegerType => tryParseInteger(field, options)
-        case LongType => tryParseLong(field, options)
-        case _: DecimalType =>
-          // DecimalTypes have different precisions and scales, so we try to find the common type.
-          compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
-        case DoubleType => tryParseDouble(field, options)
-        case TimestampType => tryParseTimestamp(field, options)
-        case BooleanType => tryParseBoolean(field, options)
-        case StringType => StringType
-        case other: DataType =>
-          throw new UnsupportedOperationException(s"Unexpected data type $other")
-      }
-    }
-  }
-
-  private def isInfOrNan(field: String, options: CSVOptions): Boolean = {
-    field == options.nanValue || field == options.negativeInf || field == options.positiveInf
-  }
-
-  private def tryParseInteger(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toInt).isDefined) {
-      IntegerType
-    } else {
-      tryParseLong(field, options)
-    }
-  }
-
-  private def tryParseLong(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toLong).isDefined) {
-      LongType
-    } else {
-      tryParseDecimal(field, options)
-    }
-  }
-
-  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
-    val decimalTry = allCatch opt {
-      // `BigDecimal` conversion can fail when the `field` is not a form of number.
-      val bigDecimal = new BigDecimal(field)
-      // Because many other formats do not support decimal, it reduces the cases for
-      // decimals by disallowing values having scale (eg. `1.1`).
-      if (bigDecimal.scale <= 0) {
-        // `DecimalType` conversion can fail when
-        //   1. The precision is bigger than 38.
-        //   2. scale is bigger than precision.
-        DecimalType(bigDecimal.precision, bigDecimal.scale)
-      } else {
-        tryParseDouble(field, options)
-      }
-    }
-    decimalTry.getOrElse(tryParseDouble(field, options))
-  }
-
-  private def tryParseDouble(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
-      DoubleType
-    } else {
-      tryParseTimestamp(field, options)
-    }
-  }
-
-  private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
-    // This case infers a custom `dataFormat` is set.
-    if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
-      TimestampType
-    } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
-      // We keep this for backwards compatibility.
-      TimestampType
-    } else {
-      tryParseBoolean(field, options)
-    }
-  }
-
-  private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toBoolean).isDefined) {
-      BooleanType
-    } else {
-      stringType()
-    }
-  }
-
-  // Defining a function to return the StringType constant is necessary in order to work around
-  // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
-  // see issue #128 for more details.
-  private def stringType(): DataType = {
-    StringType
-  }
-
-  /**
-   * Returns the common data type given two input data types so that the return type
-   * is compatible with both input data types.
-   */
-  private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = {
-    TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2))
-  }
-
-  /**
-   * The following pattern matching represents additional type promotion rules that
-   * are CSV specific.
-   */
-  private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = {
-    case (StringType, t2) => Some(StringType)
-    case (t1, StringType) => Some(StringType)
-
-    // These two cases below deal with when `IntegralType` is larger than `DecimalType`.
-    case (t1: IntegralType, t2: DecimalType) =>
-      compatibleType(DecimalType.forType(t1), t2)
-    case (t1: DecimalType, t2: IntegralType) =>
-      compatibleType(t1, DecimalType.forType(t2))
-
-    // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
-    // in most case, also have better precision.
-    case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
-      Some(DoubleType)
-
-    case (t1: DecimalType, t2: DecimalType) =>
-      val scale = math.max(t1.scale, t2.scale)
-      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
-      if (range + scale > 38) {
-        // DecimalType can't support precision > 38
-        Some(DoubleType)
-      } else {
-        Some(DecimalType(range + scale, scale))
-      }
-    case _ => None
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5348b65..f8c4d88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3870,6 +3870,41 @@ object functions {
     withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap))
   }
 
+  /**
+   * Parses a CSV string and infers its schema in DDL format.
+   *
+   * @param csv a CSV string.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv))
+
+  /**
+   * Parses a CSV string and infers its schema in DDL format.
+   *
+   * @param csv a string literal containing a CSV string.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr))
+
+  /**
+   * Parses a CSV string and infers its schema in DDL format using options.
+   *
+   * @param csv a string literal containing a CSV string.
+   * @param options options to control how the CSV is parsed. Accepts the same options as the
+   *                CSV data source. See [[DataFrameReader#csv]].
+   * @return a column with a string literal containing the schema in DDL format.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = {
+    withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap))
+  }
+
   // scalastyle:off line.size.limit
   // scalastyle:off parameter.number
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql
index d2214fd..5be6f80 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql
@@ -7,3 +7,11 @@ select from_csv('1', 'a InvalidType');
 select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE'));
 select from_csv('1', 'a INT', map('mode', 1));
 select from_csv();
+-- infer schema of csv literal
+select from_csv('1,abc', schema_of_csv('1,abc'));
+select schema_of_csv('1|abc', map('delimiter', '|'));
+select schema_of_csv(null);
+CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a');
+SELECT schema_of_csv(csvField) FROM csvTable;
+-- Clean up
+DROP VIEW IF EXISTS csvTable;

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
index f19f34a..677bbd9 100644
--- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 7
+-- Number of queries: 13
 
 
 -- !query 0
@@ -24,7 +24,7 @@ select from_csv('1', 1)
 struct<>
 -- !query 2 output
 org.apache.spark.sql.AnalysisException
-Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7
 
 
 -- !query 3
@@ -67,3 +67,53 @@ struct<>
 -- !query 6 output
 org.apache.spark.sql.AnalysisException
 Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7
+
+
+-- !query 7
+select from_csv('1,abc', schema_of_csv('1,abc'))
+-- !query 7 schema
+struct<from_csv(1,abc):struct<_c0:int,_c1:string>>
+-- !query 7 output
+{"_c0":1,"_c1":"abc"}
+
+
+-- !query 8
+select schema_of_csv('1|abc', map('delimiter', '|'))
+-- !query 8 schema
+struct<schema_of_csv(1|abc):string>
+-- !query 8 output
+struct<_c0:int,_c1:string>
+
+
+-- !query 9
+select schema_of_csv(null)
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7
+
+
+-- !query 10
+CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a')
+-- !query 10 schema
+struct<>
+-- !query 10 output
+
+
+
+-- !query 11
+SELECT schema_of_csv(csvField) FROM csvTable
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7
+
+
+-- !query 12
+DROP VIEW IF EXISTS csvTable
+-- !query 12 schema
+struct<>
+-- !query 12 output
+
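
Queries 9 and 11 above fail for the same underlying reason: schema inference runs during analysis, so the sample must be a foldable, non-null string literal, never a column reference. A sketch of the practical consequence, assuming an active `SparkSession`; the table and column names reuse the fixtures above, and the sample-and-reinject workaround is illustrative, not part of this patch:

```scala
import org.apache.spark.sql.functions.{col, lit, schema_of_csv}

// Fails analysis: a column reference is not foldable, so there is nothing to
// evaluate at planning time (this is the query 11 error above).
//   spark.table("csvTable").select(schema_of_csv(col("csvField")))

// Workaround: pull one representative row out first, then pass it back in as
// a literal, which the analyzer can evaluate.
val sample = spark.table("csvTable").select(col("csvField")).head().getString(0)
val ddl = spark.range(1).select(schema_of_csv(lit(sample))).head().getString(0)
```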

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index 38a2143..9395f05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -59,4 +59,19 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(Row(null, null, "0,2013-111-11 12:13:14")),
       Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null))))
   }
+
+  test("schema_of_csv - infers schemas") {
+    checkAnswer(
+      spark.range(1).select(schema_of_csv(lit("0.1,1"))),
+      Seq(Row("struct<_c0:double,_c1:int>")))
+    checkAnswer(
+      spark.range(1).select(schema_of_csv("0.1,1")),
+      Seq(Row("struct<_c0:double,_c1:int>")))
+  }
+
+  test("schema_of_csv - infers schemas using options") {
+    val df = spark.range(1)
+      .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava))
+    checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>")))
+  }
 }
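
A reading of the two tests above, hedged as inference from the test code rather than a documented contract: the plain `String` overload covers the no-options case, while options go through the `Column` overload with a `java.util.Map` (hence the `.asJava`), which keeps the same entry point callable from Java. The no-options form, illustrative only:

```scala
import org.apache.spark.sql.functions.schema_of_csv

// The plain String overload wraps the sample in a literal internally,
// so no lit(...) is needed for the common case.
val df = spark.range(1).select(schema_of_csv("0.1,1"))
// expected single value: "struct<_c0:double,_c1:int>"
```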

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
deleted file mode 100644
index 6b64f2f..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.csv.CSVOptions
-import org.apache.spark.sql.types._
-
-class CSVInferSchemaSuite extends SparkFunSuite {
-
-  test("String fields types are inferred correctly from null types") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
-    assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
-    assert(CSVInferSchema.inferField(NullType, "100000000000", options) == 
LongType)
-    assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
-    assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
-    assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) 
== TimestampType)
-    assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
-    assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == 
BooleanType)
-
-    val textValueOne = Long.MaxValue.toString + "0"
-    val decimalValueOne = new java.math.BigDecimal(textValueOne)
-    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
-    assert(CSVInferSchema.inferField(NullType, textValueOne, options) == 
expectedTypeOne)
-  }
-
-  test("String fields types are inferred correctly from other types") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
-    assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
-    assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == 
DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, "test", options) == 
StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) 
== TimestampType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", 
options) == TimestampType)
-    assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
-    assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == 
BooleanType)
-    assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == 
BooleanType)
-
-    val textValueOne = Long.MaxValue.toString + "0"
-    val decimalValueOne = new java.math.BigDecimal(textValueOne)
-    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
-    assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == 
expectedTypeOne)
-  }
-
-  test("Timestamp field types are inferred correctly via custom data format") {
-    var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, 
"GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
-    options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015", options) == 
TimestampType)
-  }
-
-  test("Timestamp field types are inferred correctly from other types") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == 
StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) 
== StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == 
StringType)
-  }
-
-  test("Boolean fields types are inferred correctly from other types") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == 
StringType)
-  }
-
-  test("Type arrays are merged to highest common type") {
-    assert(
-      CSVInferSchema.mergeRowTypes(Array(StringType),
-        Array(DoubleType)).deep == Array(StringType).deep)
-    assert(
-      CSVInferSchema.mergeRowTypes(Array(IntegerType),
-        Array(LongType)).deep == Array(LongType).deep)
-    assert(
-      CSVInferSchema.mergeRowTypes(Array(DoubleType),
-        Array(LongType)).deep == Array(DoubleType).deep)
-  }
-
-  test("Null fields are handled properly when a nullValue is specified") {
-    var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
-    assert(CSVInferSchema.inferField(StringType, "null", options) == 
StringType)
-    assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
-
-    options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT")
-    assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == 
IntegerType)
-    assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
-    assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == 
TimestampType)
-    assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == 
BooleanType)
-    assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == 
DecimalType(1, 1))
-  }
-
-  test("Merging Nulltypes should yield Nulltype.") {
-    val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
-    assert(mergedNullTypes.deep == Array(NullType).deep)
-  }
-
-  test("SPARK-18433: Improve DataSource option keys to be more 
case-insensitive") {
-    val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, 
"GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
-  }
-
-  test("SPARK-18877: `inferField` on DecimalType should find a common type 
with `typeSoFar`") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-
-    // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
-    assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) 
==
-      DecimalType(4, -9))
-
-    // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 
and scale 20.
-    val value = "12345678901234567890.01234567890123456789"
-    assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == 
DoubleType)
-
-    // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
-    assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) 
== DecimalType(20, 0))
-    assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 
00:00:00", options)
-      == StringType)
-  }
-
-  test("DoubleType should be inferred when user defined nan/inf are provided") 
{
-    val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> 
"-inf",
-      "positiveInf" -> "inf"), false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/c9667aff/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
deleted file mode 100644
index 6f23114..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import java.math.BigDecimal
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-class UnivocityParserSuite extends SparkFunSuite {
-  private val parser = new UnivocityParser(
-    StructType(Seq.empty),
-    new CSVOptions(Map.empty[String, String], false, "GMT"))
-
-  private def assertNull(v: Any) = assert(v == null)
-
-  test("Can parse decimal type values") {
-    val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
-    val decimalValues = Seq(10.05, 1000.01, 158058049.001)
-    val decimalType = new DecimalType()
-
-    stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
-      val decimalValue = new BigDecimal(decimalVal.toString)
-      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-      assert(parser.makeConverter("_1", decimalType, options = 
options).apply(strVal) ===
-        Decimal(decimalValue, decimalType.precision, decimalType.scale))
-    }
-  }
-
-  test("Nullable types are handled") {
-    val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
-      BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
-
-    // Nullable field with nullValue option.
-    types.foreach { t =>
-      // Tests that a custom nullValue.
-      val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT")
-      val converter =
-        parser.makeConverter("_1", t, nullable = true, options = nullValueOptions)
-      assertNull(converter.apply("-"))
-      assertNull(converter.apply(null))
-
-      // Tests that the default nullValue is empty string.
-      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-      assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply(""))
-    }
-
-    // Not nullable field with nullValue option.
-    types.foreach { t =>
-      // Casts a null to not nullable field should throw an exception.
-      val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT")
-      val converter =
-        parser.makeConverter("_1", t, nullable = false, options = options)
-      var message = intercept[RuntimeException] {
-        converter.apply("-")
-      }.getMessage
-      assert(message.contains("null value found but field _1 is not 
nullable."))
-      message = intercept[RuntimeException] {
-        converter.apply(null)
-      }.getMessage
-      assert(message.contains("null value found but field _1 is not 
nullable."))
-    }
-
-    // If nullValue is different with empty string, then, empty string should not be casted into
-    // null.
-    Seq(true, false).foreach { b =>
-      val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT")
-      val converter =
-        parser.makeConverter("_1", StringType, nullable = b, options = options)
-      assert(converter.apply("") == UTF8String.fromString(""))
-    }
-  }
-
-  test("Throws exception for empty string with non null type") {
-      val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    val exception = intercept[RuntimeException]{
-      parser.makeConverter("_1", IntegerType, nullable = false, options = 
options).apply("")
-    }
-    assert(exception.getMessage.contains("null value found but field _1 is not 
nullable."))
-  }
-
-  test("Types are cast correctly") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(parser.makeConverter("_1", ByteType, options = options).apply("10") 
== 10)
-    assert(parser.makeConverter("_1", ShortType, options = 
options).apply("10") == 10)
-    assert(parser.makeConverter("_1", IntegerType, options = 
options).apply("10") == 10)
-    assert(parser.makeConverter("_1", LongType, options = options).apply("10") 
== 10)
-    assert(parser.makeConverter("_1", FloatType, options = 
options).apply("1.00") == 1.0)
-    assert(parser.makeConverter("_1", DoubleType, options = 
options).apply("1.00") == 1.0)
-    assert(parser.makeConverter("_1", BooleanType, options = 
options).apply("true") == true)
-
-    val timestampsOptions =
-      new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT")
-    val customTimestamp = "31/01/2015 00:00"
-    val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime
-    val castedTimestamp =
-      parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions)
-        .apply(customTimestamp)
-    assert(castedTimestamp == expectedTime * 1000L)
-
-    val customDate = "31/01/2015"
-    val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT")
-    val expectedDate = dateOptions.dateFormat.parse(customDate).getTime
-    val castedDate =
-      parser.makeConverter("_1", DateType, nullable = true, options = dateOptions)
-        .apply(customTimestamp)
-    assert(castedDate == DateTimeUtils.millisToDays(expectedDate))
-
-    val timestamp = "2015-01-01 00:00:00"
-    assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) ==
-      DateTimeUtils.stringToTime(timestamp).getTime  * 1000L)
-    assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") ==
-      DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
-  }
-
-  test("Throws exception for casting an invalid string to Float and Double 
Types") {
-    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    val types = Seq(DoubleType, FloatType)
-    val input = Seq("10u000", "abc", "1 2/3")
-    types.foreach { dt =>
-      input.foreach { v =>
-        val message = intercept[NumberFormatException] {
-          parser.makeConverter("_1", dt, options = options).apply(v)
-        }.getMessage
-        assert(message.contains(v))
-      }
-    }
-  }
-
-  test("Float NaN values are parsed correctly") {
-    val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT")
-    val floatVal: Float = parser.makeConverter(
-      "_1", FloatType, nullable = true, options = options
-    ).apply("nn").asInstanceOf[Float]
-
-    // Java implements the IEEE-754 floating point standard which guarantees that any comparison
-    // against NaN will return false (except != which returns true)
-    assert(floatVal != floatVal)
-  }
-
-  test("Double NaN values are parsed correctly") {
-    val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT")
-    val doubleVal: Double = parser.makeConverter(
-      "_1", DoubleType, nullable = true, options = options
-    ).apply("-").asInstanceOf[Double]
-
-    assert(doubleVal.isNaN)
-  }
-
-  test("Float infinite values can be parsed") {
-    val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT")
-    val floatVal1 = parser.makeConverter(
-      "_1", FloatType, nullable = true, options = negativeInfOptions
-    ).apply("max").asInstanceOf[Float]
-
-    assert(floatVal1 == Float.NegativeInfinity)
-
-    val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT")
-    val floatVal2 = parser.makeConverter(
-      "_1", FloatType, nullable = true, options = positiveInfOptions
-    ).apply("max").asInstanceOf[Float]
-
-    assert(floatVal2 == Float.PositiveInfinity)
-  }
-
-  test("Double infinite values can be parsed") {
-    val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT")
-    val doubleVal1 = parser.makeConverter(
-      "_1", DoubleType, nullable = true, options = negativeInfOptions
-    ).apply("max").asInstanceOf[Double]
-
-    assert(doubleVal1 == Double.NegativeInfinity)
-
-    val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT")
-    val doubleVal2 = parser.makeConverter(
-      "_1", DoubleType, nullable = true, options = positiveInfOptions
-    ).apply("max").asInstanceOf[Double]
-
-    assert(doubleVal2 == Double.PositiveInfinity)
-  }
-
-}

