Repository: spark
Updated Branches:
  refs/heads/master a2d8d767d -> e2773996b


[SPARK-19454][PYTHON][SQL] DataFrame.replace improvements

## What changes were proposed in this pull request?

- Allows skipping `value` argument if `to_replace` is a `dict`:
        ```python
        df = sc.parallelize([("Alice", 1, 3.0)]).toDF()
        df.replace({"Alice": "Bob"}).show()
        ```
- Adds a validation step to ensure that values and replacements have homogeneous types.
- Simplifies internal control flow.
- Improves unit test coverage.

## How was this patch tested?

Existing unit tests, additional unit tests, manual testing.

Author: zero323 <zero...@users.noreply.github.com>

Closes #16793 from zero323/SPARK-19454.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e2773996
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e2773996
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e2773996

Branch: refs/heads/master
Commit: e2773996b8d1c0214d9ffac634a059b4923caf7b
Parents: a2d8d76
Author: zero323 <zero...@users.noreply.github.com>
Authored: Wed Apr 5 11:47:40 2017 -0700
Committer: Holden Karau <hol...@us.ibm.com>
Committed: Wed Apr 5 11:47:40 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 81 +++++++++++++++++++++++++-----------
 python/pyspark/sql/tests.py     | 72 ++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e2773996/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a24512f..774caf5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,6 +25,8 @@ if sys.version >= '3':
 else:
     from itertools import imap as map
 
+import warnings
+
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, 
UTF8Deserializer
@@ -1281,7 +1283,7 @@ class DataFrame(object):
             return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), 
self.sql_ctx)
 
     @since(1.4)
-    def replace(self, to_replace, value, subset=None):
+    def replace(self, to_replace, value=None, subset=None):
         """Returns a new :class:`DataFrame` replacing a value with another 
value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
@@ -1326,43 +1328,72 @@ class DataFrame(object):
         |null|  null|null|
         +----+------+----+
         """
-        if not isinstance(to_replace, (float, int, long, basestring, list, 
tuple, dict)):
+        # Helper functions
+        def all_of(types):
+            """Given a type or tuple of types and a sequence of xs
+            check if each x is instance of type(s)
+
+            >>> all_of(bool)([True, False])
+            True
+            >>> all_of(basestring)(["a", 1])
+            False
+            """
+            def all_of_(xs):
+                return all(isinstance(x, types) for x in xs)
+            return all_of_
+
+        all_of_bool = all_of(bool)
+        all_of_str = all_of(basestring)
+        all_of_numeric = all_of((float, int, long))
+
+        # Validate input types
+        valid_types = (bool, float, int, long, basestring, list, tuple)
+        if not isinstance(to_replace, valid_types + (dict, )):
             raise ValueError(
-                "to_replace should be a float, int, long, string, list, tuple, 
or dict")
+                "to_replace should be a float, int, long, string, list, tuple, 
or dict. "
+                "Got {0}".format(type(to_replace)))
 
-        if not isinstance(value, (float, int, long, basestring, list, tuple)):
-            raise ValueError("value should be a float, int, long, string, 
list, or tuple")
+        if not isinstance(value, valid_types) and not isinstance(to_replace, 
dict):
+            raise ValueError("If to_replace is not a dict, value should be "
+                             "a float, int, long, string, list, or tuple. "
+                             "Got {0}".format(type(value)))
+
+        if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, 
tuple)):
+            if len(to_replace) != len(value):
+                raise ValueError("to_replace and value lists should be of the 
same length. "
+                                 "Got {0} and {1}".format(len(to_replace), 
len(value)))
 
-        rep_dict = dict()
+        if not (subset is None or isinstance(subset, (list, tuple, 
basestring))):
+            raise ValueError("subset should be a list or tuple of column 
names, "
+                             "column name or None. Got 
{0}".format(type(subset)))
 
+        # Reshape input arguments if necessary
         if isinstance(to_replace, (float, int, long, basestring)):
             to_replace = [to_replace]
 
-        if isinstance(to_replace, tuple):
-            to_replace = list(to_replace)
+        if isinstance(value, (float, int, long, basestring)):
+            value = [value for _ in range(len(to_replace))]
 
-        if isinstance(value, tuple):
-            value = list(value)
-
-        if isinstance(to_replace, list) and isinstance(value, list):
-            if len(to_replace) != len(value):
-                raise ValueError("to_replace and value lists should be of the 
same length")
-            rep_dict = dict(zip(to_replace, value))
-        elif isinstance(to_replace, list) and isinstance(value, (float, int, 
long, basestring)):
-            rep_dict = dict([(tr, value) for tr in to_replace])
-        elif isinstance(to_replace, dict):
+        if isinstance(to_replace, dict):
             rep_dict = to_replace
+            if value is not None:
+                warnings.warn("to_replace is a dict and value is not None. 
value will be ignored.")
+        else:
+            rep_dict = dict(zip(to_replace, value))
 
-        if subset is None:
-            return DataFrame(self._jdf.na().replace('*', rep_dict), 
self.sql_ctx)
-        elif isinstance(subset, basestring):
+        if isinstance(subset, basestring):
             subset = [subset]
 
-        if not isinstance(subset, (list, tuple)):
-            raise ValueError("subset should be a list or tuple of column 
names")
+        # Verify we were not passed in mixed type generics."
+        if not any(all_of_type(rep_dict.keys()) and 
all_of_type(rep_dict.values())
+                   for all_of_type in [all_of_bool, all_of_str, 
all_of_numeric]):
+            raise ValueError("Mixed type replacements are not supported")
 
-        return DataFrame(
-            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), 
self.sql_ctx)
+        if subset is None:
+            return DataFrame(self._jdf.na().replace('*', rep_dict), 
self.sql_ctx)
+        else:
+            return DataFrame(
+                self._jdf.na().replace(self._jseq(subset), 
self._jmap(rep_dict)), self.sql_ctx)
 
     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):

http://git-wip-us.apache.org/repos/asf/spark/blob/e2773996/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index db41b4e..2b24443 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1779,6 +1779,78 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
+        # replace with lists
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], 
[u'Ann']).first()
+        self.assertTupleEqual(row, (u'Ann', 10, 80.1))
+
+        # replace with dict
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
+        self.assertTupleEqual(row, (u'Alice', 11, 80.1))
+
+        # test backward compatibility with dummy value
+        dummy_value = 1
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, 
dummy_value).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # test dict with mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 
90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', -10, 90.5))
+
+        # replace with tuples
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', 
)).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # replace multiple columns
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 
90)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.0))
+
+        # test for mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 
90.5)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        # replace with boolean
+        row = (self
+               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
+               .selectExpr("name = 'Bob'", 'age <= 15')
+               .replace(False, True).first())
+        self.assertTupleEqual(row, (True, True))
+
+        # should fail if subset is not list, tuple or None
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, 
subset=1).first()
+
+        # should fail if to_replace and value have different length
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], 
["Eve"]).first()
+
+        # should fail if when received unexpected type
+        with self.assertRaises(ValueError):
+            from datetime import datetime
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), 
datetime.now()).first()
+
+        # should fail if provided mixed type replacements
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 
20]).first()
+
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 
20}).first()
+
     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select 
abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + 
b"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to