Repository: spark
Updated Branches:
  refs/heads/master e398c2814 -> 4ce735eed
[SPARK-21394][SPARK-21432][PYTHON] Reviving callable object/partial function support in UDF in PySpark

## What changes were proposed in this pull request?

This PR proposes to exclude `__name__` (and `__module__`) from the tuple of attribute names that `functools.wraps` copies from the wrapped function to the wrapper function, and to assign `self._name` (`func.__name__` or `obj.__class__.__name__`) manually instead.

After SPARK-19161, we happened to break callable objects as UDFs in Python, as below:

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/functions.py", line 2142, in udf
    return _udf(f=f, returnType=returnType)
  File ".../spark/python/pyspark/sql/functions.py", line 2133, in _udf
    return udf_obj._wrapped()
  File ".../spark/python/pyspark/sql/functions.py", line 2090, in _wrapped
    functools.wraps(self.func)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
    setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: F instance has no attribute '__name__'
```

This worked in Spark 2.1:

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```

```
+-----+
|F(id)|
+-----+
|    0|
+-----+
```

**After**

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```

```
+-----+
|F(id)|
+-----+
|    0|
+-----+
```

_In addition, we also happened to break partial functions, as below_:

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/functions.py", line 2154, in udf
    return _udf(f=f, returnType=returnType)
  File ".../spark/python/pyspark/sql/functions.py", line 2145, in _udf
    return udf_obj._wrapped()
  File ".../spark/python/pyspark/sql/functions.py", line 2099, in _wrapped
    functools.wraps(self.func, assigned=assignments)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
    setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: 'functools.partial' object has no attribute '__module__'
```

This worked in Spark 2.1:

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```

```
+---------+
|partial()|
+---------+
|        1|
+---------+
```

**After**

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```

```
+---------+
|partial()|
+---------+
|        1|
+---------+
```

## How was this patch tested?

Unit tests in `python/pyspark/sql/tests.py` and manual tests.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #18615 from HyukjinKwon/callable-object.
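The technique can be reproduced outside Spark. Below is a minimal, self-contained sketch of the same idea: filter `functools.WRAPPER_ASSIGNMENTS` so that `functools.wraps` never copies `__name__`/`__module__` (which `functools.partial` objects and callable instances may lack), then assign those attributes manually with a fallback to the callable's class. The `make_wrapper` helper and its `getattr` fallback are illustrative stand-ins, not PySpark's actual `_wrapped` method.

```python
import functools
from functools import partial


def make_wrapper(func):
    """Wrap ``func`` without assuming it defines __name__ or __module__."""
    # functools.partial objects and callable instances may lack __name__ and
    # __module__, so drop them from the attributes that wraps() copies over.
    assignments = tuple(
        a for a in functools.WRAPPER_ASSIGNMENTS
        if a not in ('__name__', '__module__'))

    @functools.wraps(func, assigned=assignments)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    # Assign the skipped attributes manually, falling back to the callable's
    # class when the callable itself does not define them.
    wrapper.__name__ = getattr(func, '__name__', func.__class__.__name__)
    wrapper.__module__ = getattr(func, '__module__',
                                 func.__class__.__module__)
    return wrapper


class F(object):
    def __call__(self, x):
        return x


print(make_wrapper(F()).__name__)                    # F
print(make_wrapper(partial(lambda x: x)).__name__)   # partial
```

Note that plain `functools.wraps` only raises `AttributeError` for such callables on Python 2 (as in the tracebacks above); Python 3's `update_wrapper` silently skips missing attributes, but assigning `__name__` manually is still what keeps the UDF's column naming consistent.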
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ce735ee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ce735ee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ce735ee

Branch: refs/heads/master
Commit: 4ce735eed103f3bd055c087126acd1366c2537ec
Parents: e398c28
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Mon Jul 17 00:37:36 2017 -0700
Committer: Holden Karau <hol...@us.ibm.com>
Committed: Mon Jul 17 00:37:36 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py | 14 +++++++++++++-
 python/pyspark/sql/tests.py     | 21 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4ce735ee/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d45ff63..2c8c8e2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2087,10 +2087,22 @@ class UserDefinedFunction(object):
         """
         Wrap this udf with a function and attach docstring from func
         """
-        @functools.wraps(self.func)
+
+        # It is possible for a callable instance without __name__ attribute or/and
+        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
+        # we should avoid wrapping the attributes from the wrapped function to the wrapper
+        # function. So, we take out these attribute names from the default names to set and
+        # then manually assign it after being wrapped.
+        assignments = tuple(
+            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
+
+        @functools.wraps(self.func, assigned=assignments)
         def wrapper(*args):
             return self(*args)
 
+        wrapper.__name__ = self._name
+        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
+                              else self.func.__class__.__module__)
         wrapper.func = self.func
         wrapper.returnType = self.returnType

http://git-wip-us.apache.org/repos/asf/spark/blob/4ce735ee/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 29e48a6..be5495c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -679,6 +679,27 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+        class F(object):
+            """Identity"""
+            def __call__(self, x):
+                return x
+
+        f = F()
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
+        f = functools.partial(f, x=1)
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
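The new test in `tests.py` exercises both a callable class instance and a `functools.partial`. For a quick end-to-end check of the restored behavior, mirroring the examples in the commit message, the following sketch (assuming a local Spark build that includes this fix) should print the Spark 2.1 column names `F(id)` and `partial()`:

```python
from functools import partial

from pyspark.sql import SparkSession, functions

# Local session for the check; master URL is an assumption for illustration.
spark = SparkSession.builder.master("local[1]").getOrCreate()

class F(object):
    def __call__(self, x):
        return x

# Callable instance: the column is named after the class, i.e. F(id).
spark.range(1).select(functions.udf(F())("id")).show()

# Partial function: the column is named partial().
spark.range(1).select(functions.udf(partial(lambda x: x, x=1))()).show()

spark.stop()
```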