Repository: spark
Updated Branches:
  refs/heads/master 258a24341 -> da54abfd8


[SPARK-14081][SQL] - Preserve DataFrame column types when filling nulls.

## What changes were proposed in this pull request?
This change resolves an issue where `DataFrameNaFunctions.fill` changes a 
`FloatType` column to a `DoubleType`. We also clarify the contract that 
replacement values are cast to the column's data type, which may change the 
replacement value when casting to a lower-precision type (for example, a 
`Double` replacement supplied for an `IntegerType` column is converted to an 
integer).

## How was this patch tested?
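This patch extends the `fill with map` unit test in 
`DataFrameNaFunctionsSuite` to cover integer, long, float, double, string and 
boolean columns, to check that replacement values are cast to the column data 
type, and to assert that column data types are unchanged after filling.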

Author: Travis Crawford <tra...@medium.com>

Closes #11967 from traviscrawford/SPARK-14081-dataframena.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/da54abfd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/da54abfd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/da54abfd

Branch: refs/heads/master
Commit: da54abfd8730ef752eca921089bcf568773bd24a
Parents: 258a243
Author: Travis Crawford <tra...@medium.com>
Authored: Wed Mar 30 16:59:52 2016 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Mar 30 16:59:52 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/DataFrameNaFunctions.scala | 18 +++----
 .../spark/sql/DataFrameNaFunctionsSuite.scala   | 50 ++++++++++++--------
 2 files changed, 40 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/da54abfd/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 33588ef..f0e16ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    * The key of the map is the column name, and the value of the map is the 
replacement value.
    * The value must be of the following type:
    * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+   * Replacement values are cast to the column data type.
    *
    * For example, the following replaces null values in column "A" with string 
"unknown", and
    * null values in column "B" with numeric value 1.0.
@@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    *
    * The key of the map is the column name, and the value of the map is the 
replacement value.
    * The value must be of the following type: `Int`, `Long`, `Float`, 
`Double`, `String`, `Boolean`.
+   * Replacement values are cast to the column data type.
    *
    * For example, the following replaces null values in column "A" with string 
"unknown", and
    * null values in column "B" with numeric value 1.0.
@@ -386,10 +388,10 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
     val projections = df.schema.fields.map { f =>
       values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) 
=>
         v match {
-          case v: jl.Float => fillCol[Double](f, v.toDouble)
+          case v: jl.Float => fillCol[Float](f, v)
           case v: jl.Double => fillCol[Double](f, v)
-          case v: jl.Long => fillCol[Double](f, v.toDouble)
-          case v: jl.Integer => fillCol[Double](f, v.toDouble)
+          case v: jl.Long => fillCol[Long](f, v)
+          case v: jl.Integer => fillCol[Integer](f, v)
           case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
           case v: String => fillCol[String](f, v)
         }
@@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    * Returns a [[Column]] expression that replaces null value in `col` with 
`replacement`.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
-    col.dataType match {
+    val quotedColName = "`" + col.name + "`"
+    val colValue = col.dataType match {
       case DoubleType | FloatType =>
-        coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)),
-          lit(replacement).cast(col.dataType)).as(col.name)
-      case _ =>
-        coalesce(df.col("`" + col.name + "`"), 
lit(replacement).cast(col.dataType)).as(col.name)
+        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these 
types
+      case _ => df.col(quotedColName)
     }
+    coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/da54abfd/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index e348754..18e04c2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with 
SharedSQLContext {
   }
 
   test("fill with map") {
-    val df = Seq[(String, String, java.lang.Long, java.lang.Double, 
java.lang.Boolean)](
-      (null, null, null, null, null)).toDF("a", "b", "c", "d", "e")
-    checkAnswer(
-      df.na.fill(Map(
-        "a" -> "test",
-        "c" -> 1,
-        "d" -> 2.2,
-        "e" -> false
-      )),
-      Row("test", null, 1, 2.2, false))
-
-    // Test Java version
-    checkAnswer(
-      df.na.fill(Map(
-        "a" -> "test",
-        "c" -> 1,
-        "d" -> 2.2,
-        "e" -> false
-      ).asJava),
-      Row("test", null, 1, 2.2, false))
+    val df = Seq[(String, String, java.lang.Integer, java.lang.Long,
+        java.lang.Float, java.lang.Double, java.lang.Boolean)](
+      (null, null, null, null, null, null, null))
+      .toDF("stringFieldA", "stringFieldB", "integerField", "longField",
+        "floatField", "doubleField", "booleanField")
+
+    val fillMap = Map(
+      "stringFieldA" -> "test",
+      "integerField" -> 1,
+      "longField" -> 2L,
+      "floatField" -> 3.3f,
+      "doubleField" -> 4.4d,
+      "booleanField" -> false)
+
+    val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false)
+
+    checkAnswer(df.na.fill(fillMap), expectedRow)
+    checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version
+
+    // Ensure replacement values are cast to the column data type.
+    checkAnswer(df.na.fill(Map(
+      "integerField" -> 1d,
+      "longField" -> 2d,
+      "floatField" -> 3d,
+      "doubleField" -> 4d)),
+      Row(null, null, 1, 2L, 3f, 4d, null))
+
+    // Ensure column types do not change. Columns that have null values 
replaced
+    // will no longer be flagged as nullable, so do not compare schemas 
directly.
+    assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === 
df.schema.fields.map(_.dataType))
   }
 
   test("replace") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to