Repository: spark
Updated Branches:
  refs/heads/master 864d94fe8 -> 6cbc61d10


[SPARK-19732][SQL][PYSPARK] Add fill functions for nulls in bool fields of 
datasets

## What changes were proposed in this pull request?

Allow `fillna`/`na.fill` to replace nulls in boolean columns with a boolean value, in both Python and Scala.

## How was this patch tested?

Unit tests, doctests

This PR is my original work, and I license it to the Spark project.

Author: Ruben Berenguel Montoro <ru...@mostlymaths.net>
Author: Ruben Berenguel <ru...@mostlymaths.net>

Closes #18164 from rberenguel/SPARK-19732-fillna-bools.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6cbc61d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6cbc61d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6cbc61d1

Branch: refs/heads/master
Commit: 6cbc61d1070584ffbc34b1f53df352c9162f414a
Parents: 864d94f
Author: Ruben Berenguel Montoro <ru...@mostlymaths.net>
Authored: Sat Jun 3 14:56:42 2017 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Sat Jun 3 14:56:42 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 23 ++++++++++---
 python/pyspark/sql/tests.py                     | 34 +++++++++++++++-----
 .../apache/spark/sql/DataFrameNaFunctions.scala | 30 +++++++++++++++--
 .../spark/sql/DataFrameNaFunctionsSuite.scala   | 21 ++++++++++++
 4 files changed, 94 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6cbc61d1/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8d8b938..99abfcc 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1289,7 +1289,7 @@ class DataFrame(object):
         """Replace null values, alias for ``na.fill()``.
         :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are 
aliases of each other.
 
-        :param value: int, long, float, string, or dict.
+        :param value: int, long, float, string, bool or dict.
             Value to replace null values with.
             If the value is a dict, then `subset` is ignored and `value` must 
be a mapping
             from column name (string) to replacement value. The replacement 
value must be
@@ -1309,6 +1309,15 @@ class DataFrame(object):
         | 50|    50| null|
         +---+------+-----+
 
+        >>> df5.na.fill(False).show()
+        +----+-------+-----+
+        | age|   name|  spy|
+        +----+-------+-----+
+        |  10|  Alice|false|
+        |   5|    Bob|false|
+        |null|Mallory| true|
+        +----+-------+-----+
+
         >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
         +---+------+-------+
         |age|height|   name|
@@ -1319,10 +1328,13 @@ class DataFrame(object):
         | 50|  null|unknown|
         +---+------+-------+
         """
-        if not isinstance(value, (float, int, long, basestring, dict)):
-            raise ValueError("value should be a float, int, long, string, or 
dict")
+        if not isinstance(value, (float, int, long, basestring, bool, dict)):
+            raise ValueError("value should be a float, int, long, string, bool 
or dict")
+
+        # Note that bool is a subclass of int, so isinstance(True, int) holds,
+        # but we don't want to convert bools to floats
 
-        if isinstance(value, (int, long)):
+        if not isinstance(value, bool) and isinstance(value, (int, long)):
             value = float(value)
 
         if isinstance(value, dict):
@@ -1819,6 +1831,9 @@ def _test():
                                    Row(name='Bob', age=5, height=None),
                                    Row(name='Tom', age=None, height=None),
                                    Row(name=None, age=None, 
height=None)]).toDF()
+    globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10),
+                                   Row(name='Bob', spy=None, age=5),
+                                   Row(name='Mallory', spy=True, 
age=None)]).toDF()
     globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846),
                                    Row(name='Bob', time=1479442946)]).toDF()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6cbc61d1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index acea911..845e1c7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1697,40 +1697,58 @@ class SQLTests(ReusedPySparkTestCase):
         schema = StructType([
             StructField("name", StringType(), True),
             StructField("age", IntegerType(), True),
-            StructField("height", DoubleType(), True)])
+            StructField("height", DoubleType(), True),
+            StructField("spy", BooleanType(), True)])
 
         # fillna shouldn't change non-null values
-        row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], 
schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], 
schema).fillna(50).first()
         self.assertEqual(row.age, 10)
 
         # fillna with int
-        row = self.spark.createDataFrame([(u'Alice', None, None)], 
schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None, None)], 
schema).fillna(50).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.0)
 
         # fillna with double
-        row = self.spark.createDataFrame([(u'Alice', None, None)], 
schema).fillna(50.1).first()
+        row = self.spark.createDataFrame(
+            [(u'Alice', None, None, None)], schema).fillna(50.1).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.1)
 
+        # fillna with bool
+        row = self.spark.createDataFrame(
+            [(u'Alice', None, None, None)], schema).fillna(True).first()
+        self.assertEqual(row.age, None)
+        self.assertEqual(row.spy, True)
+
         # fillna with string
-        row = self.spark.createDataFrame([(None, None, None)], 
schema).fillna("hello").first()
+        row = self.spark.createDataFrame([(None, None, None, None)], 
schema).fillna("hello").first()
         self.assertEqual(row.name, u"hello")
         self.assertEqual(row.age, None)
 
         # fillna with subset specified for numeric cols
         row = self.spark.createDataFrame(
-            [(None, None, None)], schema).fillna(50, subset=['name', 
'age']).first()
+            [(None, None, None, None)], schema).fillna(50, subset=['name', 
'age']).first()
         self.assertEqual(row.name, None)
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, None)
 
-        # fillna with subset specified for numeric cols
+        # fillna with subset specified for string cols
         row = self.spark.createDataFrame(
-            [(None, None, None)], schema).fillna("haha", subset=['name', 
'age']).first()
+            [(None, None, None, None)], schema).fillna("haha", subset=['name', 
'age']).first()
         self.assertEqual(row.name, "haha")
         self.assertEqual(row.age, None)
         self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, None)
+
+        # fillna with subset specified for bool cols
+        row = self.spark.createDataFrame(
+            [(None, None, None, None)], schema).fillna(True, subset=['name', 
'spy']).first()
+        self.assertEqual(row.name, None)
+        self.assertEqual(row.age, None)
+        self.assertEqual(row.height, None)
+        self.assertEqual(row.spy, True)
 
         # fillna with dictionary for boolean types
         row = self.spark.createDataFrame([Row(a=None), 
Row(a=True)]).fillna({"a": True}).first()

http://git-wip-us.apache.org/repos/asf/spark/blob/6cbc61d1/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 052d85a..ee949e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -196,6 +196,30 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
   def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, 
cols)
 
   /**
+   * Returns a new `DataFrame` that replaces null values in boolean columns 
with `value`.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean): DataFrame = fill(value, df.columns)
+
+  /**
+   * (Scala-specific) Returns a new `DataFrame` that replaces null values in 
specified
+   * boolean columns. If a specified column is not a boolean column, it is 
ignored.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, 
cols)
+
+  /**
+   * Returns a new `DataFrame` that replaces null values in specified boolean 
columns.
+   * If a specified column is not a boolean column, it is ignored.
+   *
+   * @since 2.3.0
+   */
+  def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, 
cols.toSeq)
+
+
+  /**
    * Returns a new `DataFrame` that replaces null values.
    *
    * The key of the map is the column name, and the value of the map is the 
replacement value.
@@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in specified
-   * numeric, string columns. If a specified column is not a numeric, string 
column,
-   * it is ignored.
+   * numeric, string or boolean columns. If a specified column is not a numeric,
+   * string or boolean column, it is ignored.
    */
   private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
     // the fill[T] which T is  Long/Double,
@@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
     val targetType = value match {
       case _: Double | _: Long => NumericType
       case _: String => StringType
+      case _: Boolean => BooleanType
       case _ => throw new IllegalArgumentException(
         s"Unsupported value type ${value.getClass.getName} ($value).")
     }
@@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
       val typeMatches = (targetType, f.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
         case (StringType, dt) => dt == StringType
+        case (BooleanType, dt) => dt == BooleanType
       }
       // Only fill if the column is part of the cols list.
       if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/6cbc61d1/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index aa237d0..e63c5cb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with 
SharedSQLContext {
   test("fill") {
     val input = createDF()
 
+    val boolInput = Seq[(String, java.lang.Boolean)](
+      ("Bob", false),
+      ("Alice", null),
+      ("Mallory", true),
+      (null, null)
+    ).toDF("name", "spy")
+
     val fillNumeric = input.na.fill(50.6)
     checkAnswer(
       fillNumeric,
@@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with 
SharedSQLContext {
         Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
     assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
 
+    // boolean
+    checkAnswer(
+      boolInput.na.fill(true).select("spy"),
+      Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil)
+    assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq)
+
     // fill double with subset columns
     checkAnswer(
       input.na.fill(50.6, "age" :: Nil).select("name", "age"),
@@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with 
SharedSQLContext {
         Row("Amy", 50) ::
         Row(null, 50) :: Nil)
 
+    // fill boolean with subset columns
+    checkAnswer(
+      boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"),
+      Row("Bob", false) ::
+        Row("Alice", true) ::
+        Row("Mallory", true) ::
+        Row(null, true) :: Nil)
+
     // fill string with subset columns
     checkAnswer(
       Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", 
"col1" :: Nil),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to