This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 31e7c37  [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early
31e7c37 is described below

commit 31e7c37354132545da59bff176af1613bd09447c
Author: WeichenXu <weichen...@databricks.com>
AuthorDate: Fri Jun 28 17:10:25 2019 +0900

    [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early
    
    ## What changes were proposed in this pull request?
    
     Closes the generator when Python UDFs stop early.
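
    For context, this change relies on standard Python generator semantics: calling `close()` on a suspended generator raises `GeneratorExit` at the `yield`, so the generator's `except`/`finally` blocks run promptly instead of waiting for garbage collection. A minimal, illustrative sketch (not part of this change):

    ```python
    def gen():
        try:
            for i in range(3):
                yield i
        finally:
            print("cleanup ran")

    g = gen()
    next(g)    # the generator is now suspended at the yield
    g.close()  # raises GeneratorExit inside gen(); the finally block runs here
    ```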
    
    ### Manual verification on pandas iterator UDF and mapPartitions
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.functions import col, udf
    from pyspark.taskcontext import TaskContext
    import time
    import os
    
    spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1')
    spark.conf.set('spark.sql.pandas.udf.buffer.size', '4')
    
    pandas_udf("int", PandasUDFType.SCALAR_ITER)
    def fi1(it):
        try:
            for batch in it:
                yield batch + 100
                time.sleep(1.0)
        except BaseException as be:
            print("Debug: exception raised: " + str(type(be)))
            raise be
        finally:
            open("/tmp/000001.tmp", "a").close()
    
    df1 = spark.range(10).select(col('id').alias('a')).repartition(1)
    
    # will see log Debug: exception raised: <class 'GeneratorExit'>
    # and file "/tmp/000001.tmp" generated.
    df1.select(col('a'), fi1('a')).limit(2).collect()
    
    def mapper(it):
        try:
            for batch in it:
                yield batch
        except BaseException as be:
            print("Debug: exception raised: " + str(type(be)))
            raise be
        finally:
            open("/tmp/000002.tmp", "a").close()
    
    df2 = spark.range(10000000).repartition(1)
    
    # will see log Debug: exception raised: <class 'GeneratorExit'>
    # and file "/tmp/000002.tmp" generated.
    df2.rdd.mapPartitions(mapper).take(2)
    
    ```
    
    ## How was this patch tested?
    
    Unit test added.
    
    Please review https://spark.apache.org/contributing.html before opening a pull request.
    
    Closes #24986 from WeichenXu123/pandas_iter_udf_limit.
    
    Authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 37 ++++++++++++++++++++++
 python/pyspark/worker.py                           |  7 +++-
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index c291d42..d254508 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -850,6 +850,43 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             with self.assertRaisesRegexp(Exception, "reached finally block"):
                 self.spark.range(1).select(test_close(col("id"))).collect()
 
+    def test_scalar_iter_udf_close_early(self):
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            tmp_file = tmp_dir + '/reach_finally_block'
+
+            @pandas_udf('int', PandasUDFType.SCALAR_ITER)
+            def test_close(batch_iter):
+                generator_exit_caught = False
+                try:
+                    for batch in batch_iter:
+                        yield batch
+                        time.sleep(1.0)  # avoid the function finishing too fast.
+                except GeneratorExit as ge:
+                    generator_exit_caught = True
+                    raise ge
+                finally:
+                    assert generator_exit_caught, "Generator exit exception was not caught."
+                    open(tmp_file, 'a').close()
+
+            with QuietTest(self.sc):
+                with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1,
+                                    "spark.sql.pandas.udf.buffer.size": 4}):
+                    self.spark.range(10).repartition(1) \
+                        .select(test_close(col("id"))).limit(2).collect()
+                    # wait here because the python udf worker will take some time to detect
+                    # that the jvm-side socket closed, which then triggers `GeneratorExit`.
+                    # wait timeout is 10s.
+                    for i in range(100):
+                        time.sleep(0.1)
+                        if os.path.exists(tmp_file):
+                            break
+
+                    assert os.path.exists(tmp_file), "finally block not reached."
+
+        finally:
+            shutil.rmtree(tmp_dir)
+
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
        # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ee46bb6..04376c9 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -481,7 +481,12 @@ def main(infile, outfile):
 
         def process():
             iterator = deserializer.load_stream(infile)
-            serializer.dump_stream(func(split_index, iterator), outfile)
+            out_iter = func(split_index, iterator)
+            try:
+                serializer.dump_stream(out_iter, outfile)
+            finally:
+                if hasattr(out_iter, 'close'):
+                    out_iter.close()
 
         if profiler:
             profiler.profile(process)
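
For reference, a standalone sketch of the pattern the `process()` change above introduces: the output iterator is closed when the downstream write stops early, so user-level `finally` blocks in the UDF generator run right away. Names such as `fake_dump_stream` below are illustrative stand-ins, not Spark APIs:

```python
def fake_dump_stream(out_iter):
    # stand-in for serializer.dump_stream(): consumes only part of the iterator
    for i, item in enumerate(out_iter):
        if i == 1:
            raise RuntimeError("downstream stopped early")

def udf_like_generator(it):
    try:
        for x in it:
            yield x
    finally:
        print("user cleanup ran")

out_iter = udf_like_generator(iter(range(10)))
try:
    fake_dump_stream(out_iter)
except RuntimeError:
    pass
finally:
    # without close(), the finally block above would only run whenever the
    # generator object happens to be garbage collected
    if hasattr(out_iter, 'close'):
        out_iter.close()
```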

