This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 8ad9b83  [SPARK-31965][TESTS][PYTHON] Move doctests related to Java 
function registration to test conditionally
8ad9b83 is described below

commit 8ad9b83edc239eae6b468d619419af5c0f41b4d0
Author: HyukjinKwon <gurwls...@apache.org>
AuthorDate: Wed Jun 10 21:15:40 2020 -0700

    [SPARK-31965][TESTS][PYTHON] Move doctests related to Java function 
registration to test conditionally
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to move the doctests in `registerJavaUDAF` and 
`registerJavaFunction` to the proper unittests that run conditionally when the 
test classes are present.
    
    Both tests are dependent on the test classes in JVM side, 
`test.org.apache.spark.sql.JavaStringLength` and 
`test.org.apache.spark.sql.MyDoubleAvg`. So if you run the tests against the 
plain `sbt package`, it fails as below:
    
    ```
    **********************************************************************
    File "/.../spark/python/pyspark/sql/udf.py", line 366, in pyspark.sql.udf.UDFRegistration.registerJavaFunction
    Failed example:
        spark.udf.registerJavaFunction(
            "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    Exception raised:
        Traceback (most recent call last):
       ...
    test.org.apache.spark.sql.JavaStringLength, please make sure it is on the classpath;
    ...
       6 of   7 in pyspark.sql.udf.UDFRegistration.registerJavaFunction
       2 of   4 in pyspark.sql.udf.UDFRegistration.registerJavaUDAF
    ***Test Failed*** 8 failures.
    ```
    
    ### Why are the changes needed?
    
    In order to support to run the tests against the plain SBT build. See also 
https://spark.apache.org/developer-tools.html
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's test-only.
    
    ### How was this patch tested?
    
    Manually tested as below:
    
    ```bash
    ./build/sbt -DskipTests -Phive-thriftserver clean package
    cd python
    ./run-tests --python-executable=python3 --testname="pyspark.sql.udf UserDefinedFunction"
    ./run-tests --python-executable=python3 --testname="pyspark.sql.tests.test_udf UDFTests"
    ```
    
    ```bash
    ./build/sbt -DskipTests -Phive-thriftserver clean test:package
    cd python
    ./run-tests --python-executable=python3 --testname="pyspark.sql.udf UserDefinedFunction"
    ./run-tests --python-executable=python3 --testname="pyspark.sql.tests.test_udf UDFTests"
    ```
    
    Closes #28795 from HyukjinKwon/SPARK-31965.
    
    Authored-by: HyukjinKwon <gurwls...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
    (cherry picked from commit 56264fb5d3ad1a488be5e08feb2e0304d1c2ed6a)
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 python/pyspark/sql/tests/test_udf.py | 28 ++++++++++++++++++++++++++++
 python/pyspark/sql/udf.py            | 14 +++++++++-----
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 061d3f5..ea7ec9f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -21,6 +21,8 @@ import shutil
 import tempfile
 import unittest
 
+import py4j
+
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
 from pyspark.sql.functions import UserDefinedFunction, udf
@@ -357,6 +359,32 @@ class UDFTests(ReusedSQLTestCase):
             df.select(add_four("id").alias("plus_four")).collect()
         )
 
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_function(self):
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        [value] = self.spark.sql("SELECT javaStringLength('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
+        [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
+        [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
+        self.assertEqual(value, 4)
+
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_udaf(self):
+        self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
+        df.createOrReplaceTempView("df")
+        row = self.spark.sql(
+            "SELECT name, javaUDAF(id) as avg from df group by name order by name desc").first()
+        self.assertEqual(row.asDict(), Row(name='b', avg=102.0).asDict())
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 10546ec..da68583 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -365,17 +365,20 @@ class UDFRegistration(object):
         >>> from pyspark.sql.types import IntegerType
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
-        >>> spark.sql("SELECT javaStringLength('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength(test)=4)]
 
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
-        >>> spark.sql("SELECT javaStringLength2('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength2('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength2(test)=4)]
 
         >>> spark.udf.registerJavaFunction(
         ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
-        >>> spark.sql("SELECT javaStringLength3('test')").collect()
+        ... # doctest: +SKIP
+        >>> spark.sql("SELECT javaStringLength3('test')").collect()  # doctest: +SKIP
         [Row(javaStringLength3(test)=4)]
         """
 
@@ -395,10 +398,11 @@ class UDFRegistration(object):
         :param javaClassName: fully qualified name of java class
 
         >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        ... # doctest: +SKIP
         >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
         >>> df.createOrReplaceTempView("df")
-        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name order by name desc") \
-                .collect()
+        >>> q = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
+        >>> spark.sql(q).collect()  # doctest: +SKIP
         [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
         """
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to