This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 31e7c37 [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early 31e7c37 is described below commit 31e7c37354132545da59bff176af1613bd09447c Author: WeichenXu <weichen...@databricks.com> AuthorDate: Fri Jun 28 17:10:25 2019 +0900 [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early ## What changes were proposed in this pull request? Closes the generator when Python UDFs stop early. ### Manual verification on pandas iterator UDF and mapPartitions ```python from pyspark.sql import SparkSession from pyspark.sql.functions import pandas_udf, PandasUDFType from pyspark.sql.functions import col, udf from pyspark.taskcontext import TaskContext import time import os spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1') spark.conf.set('spark.sql.pandas.udf.buffer.size', '4') @pandas_udf("int", PandasUDFType.SCALAR_ITER) def fi1(it): try: for batch in it: yield batch + 100 time.sleep(1.0) except BaseException as be: print("Debug: exception raised: " + str(type(be))) raise be finally: open("/tmp/000001.tmp", "a").close() df1 = spark.range(10).select(col('id').alias('a')).repartition(1) # will see log Debug: exception raised: <class 'GeneratorExit'> # and file "/tmp/000001.tmp" generated. df1.select(col('a'), fi1('a')).limit(2).collect() def mapper(it): try: for batch in it: yield batch except BaseException as be: print("Debug: exception raised: " + str(type(be))) raise be finally: open("/tmp/000002.tmp", "a").close() df2 = spark.range(10000000).repartition(1) # will see log Debug: exception raised: <class 'GeneratorExit'> # and file "/tmp/000002.tmp" generated. df2.rdd.mapPartitions(mapper).take(2) ``` ## How was this patch tested? Unit test added. Please review https://spark.apache.org/contributing.html before opening a pull request. Closes #24986 from WeichenXu123/pandas_iter_udf_limit. 
Authored-by: WeichenXu <weichen...@databricks.com> Signed-off-by: HyukjinKwon <gurwls...@apache.org> --- python/pyspark/sql/tests/test_pandas_udf_scalar.py | 37 ++++++++++++++++++++++ python/pyspark/worker.py | 7 +++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index c291d42..d254508 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -850,6 +850,43 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): with self.assertRaisesRegexp(Exception, "reached finally block"): self.spark.range(1).select(test_close(col("id"))).collect() + def test_scalar_iter_udf_close_early(self): + tmp_dir = tempfile.mkdtemp() + try: + tmp_file = tmp_dir + '/reach_finally_block' + + @pandas_udf('int', PandasUDFType.SCALAR_ITER) + def test_close(batch_iter): + generator_exit_caught = False + try: + for batch in batch_iter: + yield batch + time.sleep(1.0) # avoid the function finish too fast. + except GeneratorExit as ge: + generator_exit_caught = True + raise ge + finally: + assert generator_exit_caught, "Generator exit exception was not caught." + open(tmp_file, 'a').close() + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1, + "spark.sql.pandas.udf.buffer.size": 4}): + self.spark.range(10).repartition(1) \ + .select(test_close(col("id"))).limit(2).collect() + # wait here because python udf worker will take some time to detect + # jvm side socket closed and then will trigger `GenerateExit` raised. + # wait timeout is 10s. + for i in range(100): + time.sleep(0.1) + if os.path.exists(tmp_file): + break + + assert os.path.exists(tmp_file), "finally block not reached." 
+ + finally: + shutil.rmtree(tmp_dir) + # Regression test for SPARK-23314 def test_timestamp_dst(self): # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index ee46bb6..04376c9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -481,7 +481,12 @@ def main(infile, outfile): def process(): iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + out_iter = func(split_index, iterator) + try: + serializer.dump_stream(out_iter, outfile) + finally: + if hasattr(out_iter, 'close'): + out_iter.close() if profiler: profiler.profile(process) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org