Repository: spark
Updated Branches:
  refs/heads/master af8b6cc82 -> 8ddbc431d


[SPARK-20685] Fix BatchPythonEvaluation bug in case of single UDF w/ repeated arg.

## What changes were proposed in this pull request?

There's a latent corner-case bug in PySpark UDF evaluation: executing a 
`BatchPythonEvaluation` that contains a single multi-argument UDF in which _at 
least one argument value is repeated_ crashes at execution time with a 
confusing error.

This problem was introduced in #12057: the code there has a fast path for 
handling a "batch UDF evaluation consisting of a single Python UDF", but that 
branch incorrectly assumes that a single UDF won't have repeated arguments and 
therefore skips the code for unpacking arguments from the input row. The 
schema of that row does not necessarily match the UDF inputs, because repeated 
arguments are de-duplicated in the JVM before the UDF inputs are sent to 
Python.
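
To make the failure mode concrete, here is a minimal standalone sketch of the 
old fast path. It does not use PySpark: the names `input_row`, `old_mapper`, 
and `fixed_mapper` are hypothetical, and the row and offsets shown for 
`add(1, 1)` are an assumption about what the de-duplicated input looks like.

```python
# Hypothetical stand-in for a two-argument Python UDF.
def add(x, y):
    return x + y

# Assumption: for add(1, 1) the repeated argument is sent only once, together
# with offsets saying where each UDF argument lives in the input row.
input_row = (1,)
arg_offsets = [0, 0]

# Old fast path: splat the row straight into the UDF, ignoring the offsets.
old_mapper = lambda a: add(*a)

try:
    old_mapper(input_row)
    fast_path_error = None
except TypeError as e:
    # add() receives one positional argument instead of two.
    fast_path_error = e

# Offset-aware unpacking, as the "multiple UDFs" branch does.
fixed_mapper = lambda a: add(*[a[o] for o in arg_offsets])
fixed_result = fixed_mapper(input_row)
```

Under these assumptions the fast path raises a `TypeError` about a missing 
positional argument, while the offset-aware mapper evaluates the UDF normally.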

The fix here is simply to remove this special-casing: it turns out that the 
code in the "multiple UDFs" branch just so happens to work for the single-UDF 
case because Python treats `(x)` as equivalent to `x`, not as a single-argument 
tuple.
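
As a rough illustration of why the general branch also covers the single-UDF 
case, the sketch below rebuilds the generated mapper outside of PySpark. The 
helper `build_mapper` and the sample UDFs are hypothetical; only the string 
construction mirrors the code in `read_udfs`.

```python
def build_mapper(udfs_with_offsets):
    """Build a row mapper from (arg_offsets, udf) pairs via a generated lambda."""
    namespace = {}
    calls = []
    for i, (arg_offsets, udf) in enumerate(udfs_with_offsets):
        namespace['f%d' % i] = udf
        args = ["a[%d]" % o for o in arg_offsets]
        calls.append("f%d(%s)" % (i, ", ".join(args)))
    # With one call the source is "lambda a: (f0(...))": plain parentheses,
    # not a tuple. With several calls the commas make it a tuple.
    mapper_str = "lambda a: (%s)" % ", ".join(calls)
    return mapper_str, eval(mapper_str, namespace)

# Hypothetical sample UDFs.
add = lambda x, y: x + y
double = lambda x: x * 2

single_src, single = build_mapper([([0, 0], add)])
multi_src, multi = build_mapper([([0, 0], add), ([1], double)])

single_result = single((1,))    # bare value, not a 1-tuple
multi_result = multi((1, 5))    # tuple with one result per UDF
```

The single-UDF mapper returns a bare result rather than a one-element tuple, 
which is the format the JVM side expects according to the comment added in 
`worker.py`.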

## How was this patch tested?

New regression test in the `pyspark.sql.tests` module (tested and confirmed 
that it fails before my fix).

Author: Josh Rosen <joshro...@databricks.com>

Closes #17927 from JoshRosen/SPARK-20685.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ddbc431
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ddbc431
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ddbc431

Branch: refs/heads/master
Commit: 8ddbc431d8b21d5ee57d3d209a4f25e301f15283
Parents: af8b6cc
Author: Josh Rosen <joshro...@databricks.com>
Authored: Wed May 10 16:50:57 2017 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed May 10 16:50:57 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py |  6 ++++++
 python/pyspark/worker.py    | 29 +++++++++++++----------------
 2 files changed, 19 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8ddbc431/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e3fe01e..8707500 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -330,6 +330,12 @@ class SQLTests(ReusedPySparkTestCase):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddbc431/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 25ee475..baaa3fe 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -87,22 +87,19 @@ def read_single_udf(pickleSer, infile):
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)

