This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5da7b6878c20 [SPARK-47763][CONNECT][TESTS] Enable local-cluster tests with pyspark-connect package
5da7b6878c20 is described below

commit 5da7b6878c2083fc50cb345233e9dac03bf806ac
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Apr 17 21:16:50 2024 +0900

    [SPARK-47763][CONNECT][TESTS] Enable local-cluster tests with pyspark-connect package
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to extend the `pyspark-connect` scheduled job to also run the `pyspark.resource` tests against a local-cluster Spark Connect server.
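    
    For illustration, the override hook the updated tests rely on looks roughly like this; the environment variable name and the fallback master come from the diff below, while the rest of the snippet is a minimal, hypothetical sketch:
    
    ```python
    import os
    
    from pyspark.sql import SparkSession
    
    # The scheduled job can point the tests at an already-running Spark Connect
    # server (e.g. "sc://localhost") via SPARK_CONNECT_TESTING_REMOTE; when the
    # variable is unset, the test falls back to its own local-cluster master.
    remote = os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[1, 4, 1024]")
    spark = SparkSession.builder.remote(remote).getOrCreate()
    spark.range(10).show()
    ```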
    
    ### Why are the changes needed?
    
    In order to make sure the pure Python library (`pyspark-connect`) works with `pyspark.resource`.
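    
    For context, a minimal sketch of what that coverage exercises through the pure Python client: building a resource profile and attaching it to a `mapInPandas` stage, as the re-enabled tests do. It assumes a Spark Connect server is reachable at `sc://localhost` and can satisfy the task CPU request; the values are illustrative only.
    
    ```python
    from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
    from pyspark.sql import SparkSession
    
    # Pure Python client: no JVM, Py4J, or pyspark.core on this side.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    
    def func(iterator):
        # Pass pandas batches through unchanged; the point is only that the
        # stage runs under the requested resource profile.
        for batch in iterator:
            yield batch
    
    # Require 3 CPUs per task, mirroring the TaskResourceRequests used in the tests.
    rp = ResourceProfileBuilder().require(TaskResourceRequests().cpus(3)).build
    spark.range(10).mapInPandas(func, "id long", False, rp).show(n=10)
    ```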
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    Tested in my own fork: https://github.com/HyukjinKwon/spark/actions/runs/8718980385/job/23917348664
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46090 from HyukjinKwon/enable-resources.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build_python_connect.yml         | 31 +++++++++++++++++++---
 python/packaging/connect/setup.py                  |  1 +
 .../resource/tests/test_connect_resources.py       | 15 ++++++-----
 .../sql/tests/connect/client/test_artifact.py      | 13 +++++----
 python/pyspark/sql/tests/connect/test_resources.py | 15 +++++------
 python/pyspark/sql/tests/test_resources.py         | 12 +++++----
 6 files changed, 56 insertions(+), 31 deletions(-)

diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index 863980b0c2e5..3e11dec14741 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -29,6 +29,7 @@ jobs:
     name: "Build modules: pyspark-connect"
     runs-on: ubuntu-latest
     timeout-minutes: 300
+    if: github.repository == 'apache/spark'
     steps:
       - name: Checkout Spark repository
         uses: actions/checkout@v4
@@ -80,19 +81,43 @@ jobs:
           # Make less noisy
           cp conf/log4j2.properties.template conf/log4j2.properties
           sed -i 's/rootLogger.level = info/rootLogger.level = warn/g' conf/log4j2.properties
-          # Start a Spark Connect server
+
+          # Start a Spark Connect server for local
           PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \
             --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \
             --jars "`find connector/connect/server/target -name spark-connect-*SNAPSHOT.jar`,`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`"
+
           # Make sure running Python workers that contains pyspark.core once. They will be reused.
           python -c "from pyspark.sql import SparkSession; _ = SparkSession.builder.remote('sc://localhost').getOrCreate().range(100).repartition(100).mapInPandas(lambda x: x, 'id INT').collect()"
+
           # Remove Py4J and PySpark zipped library to make sure there is no JVM connection
-          rm python/lib/*
-          rm -r python/pyspark
+          mv python/lib lib.back
+          mv python/pyspark pyspark.back
+
           # Several tests related to catalog requires to run them sequencially, e.g., writing a table in a listener.
           ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect
           # None of tests are dependent on each other in Pandas API on Spark so run them in parallel
           ./python/run-tests --parallelism=4 --python-executables=python3 --modules pyspark-pandas-connect-part0,pyspark-pandas-connect-part1,pyspark-pandas-connect-part2,pyspark-pandas-connect-part3
+
+          # Stop Spark Connect server.
+          ./sbin/stop-connect-server.sh
+          mv lib.back python/lib
+          mv pyspark.back python/pyspark
+
+          # Start a Spark Connect server for local-cluster
+          PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh \
+            --master "local-cluster[2, 4, 1024]" \
+            --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" \
+            --jars "`find connector/connect/server/target -name spark-connect-*SNAPSHOT.jar`,`find connector/protobuf/target -name spark-protobuf-*SNAPSHOT.jar`,`find connector/avro/target -name spark-avro*SNAPSHOT.jar`"
+
+          # Make sure running Python workers that contains pyspark.core once. They will be reused.
+          python -c "from pyspark.sql import SparkSession; _ = SparkSession.builder.remote('sc://localhost').getOrCreate().range(100).repartition(100).mapInPandas(lambda x: x, 'id INT').show(n=100)" > /dev/null
+
+          # Remove Py4J and PySpark zipped library to make sure there is no JVM connection
+          mv python/lib lib.back
+          mv python/pyspark lib.back
+
+          ./python/run-tests --parallelism=1 --python-executables=python3 --testnames "pyspark.resource.tests.test_connect_resources,pyspark.sql.tests.connect.client.test_artifact,pyspark.sql.tests.connect.test_resources"
       - name: Upload test results to report
         if: always()
         uses: actions/upload-artifact@v4
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index 19925962804b..3f2d79a641bc 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -70,6 +70,7 @@ if "SPARK_TESTING" in os.environ:
     test_packages = [
         "pyspark.tests",  # for Memory profiler parity tests
         "pyspark.testing",
+        "pyspark.resource.tests",
         "pyspark.sql.tests",
         "pyspark.sql.tests.connect",
         "pyspark.sql.tests.connect.streaming",
diff --git a/python/pyspark/resource/tests/test_connect_resources.py b/python/pyspark/resource/tests/test_connect_resources.py
index 1529a33cb0ad..90bae85c2a1b 100644
--- a/python/pyspark/resource/tests/test_connect_resources.py
+++ b/python/pyspark/resource/tests/test_connect_resources.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 import unittest
+import os
 
 from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests, ExecutorResourceRequests
 from pyspark.sql import SparkSession
@@ -35,20 +36,20 @@ class ResourceProfileTests(unittest.TestCase):
         # check taskResources, similar to executorResources.
         self.assertEqual(rp.taskResources["cpus"].amount, 2.0)
 
-        # SparkContext is not initialized and is not remote.
-        with self.assertRaisesRegex(
-            RuntimeError, "SparkContext must be created to get the profile id."
-        ):
+        # SparkContext or SparkSession is not initialized.
+        with self.assertRaises(RuntimeError):
             rp.id
 
         # Remote mode.
-        spark = SparkSession.builder.remote("local-cluster[1, 2, 1024]").getOrCreate()
+        spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[1, 2, 1024]")
+        ).getOrCreate()
         # Still can access taskResources, similar to executorResources.
         self.assertEqual(rp.taskResources["cpus"].amount, 2.0)
         rp.id
         df = spark.range(10)
-        df.mapInPandas(lambda x: x, df.schema, False, rp).collect()
-        df.mapInArrow(lambda x: x, df.schema, False, rp).collect()
+        df.mapInPandas(lambda x: x, df.schema, False, rp).show(n=10)
+        df.mapInArrow(lambda x: x, df.schema, False, rp).show(n=10)
 
         def assert_request_contents(exec_reqs, task_reqs):
             self.assertEqual(len(exec_reqs), 6)
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index f4f49ab25126..d0456f06d7a1 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -25,7 +25,7 @@ from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.testing.utils import SPARK_HOME
-from pyspark.sql.functions import udf
+from pyspark.sql.functions import udf, assert_true, lit
 
 if should_test_connect:
     from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -46,7 +46,7 @@ class ArtifactTestsMixin:
                 return my_pyfile.my_func()
 
             spark_session.addArtifacts(pyfile_path, pyfile=True)
-            self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 10)
+            spark_session.range(1).select(assert_true(func("id") == lit(10))).show()
 
     def test_add_pyfile(self):
         self.check_add_pyfile(self.spark)
@@ -94,7 +94,7 @@ class ArtifactTestsMixin:
                 return my_zipfile.my_func()
 
             spark_session.addArtifacts(f"{package_path}.zip", pyfile=True)
-            self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 5)
+            spark_session.range(1).select(assert_true(func("id") == lit(5))).show()
 
     def test_add_zipped_package(self):
         self.check_add_zipped_package(self.spark)
@@ -130,7 +130,7 @@ class ArtifactTestsMixin:
                 ) as my_file:
                     return my_file.read().strip()
 
-            self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "hello world!")
+            spark_session.range(1).select(assert_true(func("id") == lit("hello world!"))).show()
 
     def test_add_archive(self):
         self.check_add_archive(self.spark)
@@ -160,7 +160,7 @@ class ArtifactTestsMixin:
                 with open(os.path.join(root, "my_file.txt"), "r") as my_file:
                     return my_file.read().strip()
 
-            self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "Hello world!!")
+            spark_session.range(1).select(assert_true(func("id") == lit("Hello world!!"))).show()
 
     def test_add_file(self):
         self.check_add_file(self.spark)
@@ -427,7 +427,6 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
                 )
 
 
-@unittest.skipIf(is_remote_only(), "Requires local cluster to run")
 class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod
     def conf(cls):
@@ -442,7 +441,7 @@ class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
 
     @classmethod
     def master(cls):
-        return "local-cluster[2,2,512]"
+        return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[2,2,512]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_resources.py b/python/pyspark/sql/tests/connect/test_resources.py
index 931acd929804..94d71b54ff05 100644
--- a/python/pyspark/sql/tests/connect/test_resources.py
+++ b/python/pyspark/sql/tests/connect/test_resources.py
@@ -15,19 +15,16 @@
 # limitations under the License.
 #
 import unittest
+import os
 
-from pyspark.util import is_remote_only
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.test_resources import ResourceProfileTestsMixin
 
 
-# TODO(SPARK-47757): Reeanble ResourceProfileTests for pyspark-connect
-if not is_remote_only():
-    from pyspark.sql.tests.test_resources import ResourceProfileTestsMixin
-
-    class ResourceProfileTests(ResourceProfileTestsMixin, ReusedConnectTestCase):
-        @classmethod
-        def master(cls):
-            return "local-cluster[1, 4, 1024]"
+class ResourceProfileTests(ResourceProfileTestsMixin, ReusedConnectTestCase):
+    @classmethod
+    def master(cls):
+        return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[1, 4, 1024]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_resources.py b/python/pyspark/sql/tests/test_resources.py
index 9dfb14d9c37f..4ce61e9f763d 100644
--- a/python/pyspark/sql/tests/test_resources.py
+++ b/python/pyspark/sql/tests/test_resources.py
@@ -16,7 +16,7 @@
 #
 import unittest
 
-from pyspark import SparkContext, TaskContext
+from pyspark import TaskContext
 from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
 from pyspark.sql import SparkSession
 from pyspark.testing.sqlutils import (
@@ -41,7 +41,7 @@ class ResourceProfileTestsMixin(object):
                 yield batch
 
         df = self.spark.range(10)
-        df.mapInArrow(func, "id long").collect()
+        df.mapInArrow(func, "id long").show(n=10)
 
     def test_map_in_arrow_with_profile(self):
         def func(iterator):
@@ -54,7 +54,7 @@ class ResourceProfileTestsMixin(object):
 
         treqs = TaskResourceRequests().cpus(3)
         rp = ResourceProfileBuilder().require(treqs).build
-        df.mapInArrow(func, "id long", False, rp).collect()
+        df.mapInArrow(func, "id long", False, rp).show(n=10)
 
     def test_map_in_pandas_without_profile(self):
         def func(iterator):
@@ -64,7 +64,7 @@ class ResourceProfileTestsMixin(object):
                 yield batch
 
         df = self.spark.range(10)
-        df.mapInPandas(func, "id long").collect()
+        df.mapInPandas(func, "id long").show(n=10)
 
     def test_map_in_pandas_with_profile(self):
         def func(iterator):
@@ -77,12 +77,14 @@ class ResourceProfileTestsMixin(object):
 
         treqs = TaskResourceRequests().cpus(3)
         rp = ResourceProfileBuilder().require(treqs).build
-        df.mapInPandas(func, "id long", False, rp).collect()
+        df.mapInPandas(func, "id long", False, rp).show(n=10)
 
 
 class ResourceProfileTests(ResourceProfileTestsMixin, ReusedPySparkTestCase):
     @classmethod
     def setUpClass(cls):
+        from pyspark.core.context import SparkContext
+
         cls.sc = SparkContext("local-cluster[1, 4, 1024]", cls.__name__, conf=cls.conf())
         cls.spark = SparkSession(cls.sc)
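
A note on the artifact test changes above: the client-side `assertEqual(... .first()[0], ...)` checks become assertions built from `assert_true` and `lit`, so the comparison is evaluated by the server while the query runs instead of on a value collected back to the client. A minimal sketch of that pattern (the session setup and the UDF body are illustrative, not taken from the patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import assert_true, lit, udf

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

@udf("long")
def func(x):
    # Stand-in for a UDF that depends on an added artifact (pyfile, zip, archive, ...).
    return 10

# assert_true fails the query on the server side if the predicate is false;
# show() just forces execution, so nothing needs to be collected locally.
spark.range(1).select(assert_true(func("id") == lit(10))).show()
```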
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
