Repository: spark
Updated Branches:
  refs/heads/branch-1.3 bc04fa2e2 -> 98f72dfc1


[SPARK-6553] [pyspark] Support functools.partial as UDF

Fall back to `f.__class__.__name__` when `f` has no `__name__` attribute when
instantiating `UserDefinedFunction`s, so `functools.partial`s and other
callable objects may be used as UDFs.

Author: ksonj <k...@siberie.de>

Closes #5206 from ksonj/partials and squashes the following commits:

ea66f3d [ksonj] Inserted blank lines for PEP8 compliance
d81b02b [ksonj] added tests for udf with partial function and callable object
2c76100 [ksonj] Makes UDFs work with all types of callables
b814a12 [ksonj] support functools.partial as udf


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98f72dfc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98f72dfc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98f72dfc

Branch: refs/heads/branch-1.3
Commit: 98f72dfc17853b570d05c20e97c78919682b6df6
Parents: bc04fa2
Author: ksonj <k...@siberie.de>
Authored: Wed Apr 1 17:23:57 2015 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Wed Apr 1 17:23:57 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py |  3 ++-
 python/pyspark/sql/tests.py     | 31 +++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98f72dfc/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8a478fd..146ba6f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -123,7 +123,8 @@ class UserDefinedFunction(object):
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
                                                  includes, sc.pythonExec, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf

http://git-wip-us.apache.org/repos/asf/spark/blob/98f72dfc/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 258464b..b3a6a2c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -25,6 +25,7 @@ import pydoc
 import shutil
 import tempfile
 import pickle
+import functools
 
 import py4j
 
@@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
+from pyspark.sql.functions import UserDefinedFunction
 
 
 class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_udf_with_callable(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        class PlusFour:
+            def __call__(self, col):
+                if col is not None:
+                    return col + 4
+
+        call = PlusFour()
+        pudf = UserDefinedFunction(call, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+    def test_udf_with_partial_function(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        def some_func(col, param):
+            if col is not None:
+                return col + param
+
+        pfunc = functools.partial(some_func, param=4)
+        pudf = UserDefinedFunction(pfunc, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
     def test_udf(self):
         self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
         [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to