This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 7405a40930f [SPARK-41379][SS][PYTHON] Provide cloned spark session in DataFrame in user function for foreachBatch sink in PySpark

7405a40930f is described below

commit 7405a40930fc58337df1d21822c913d727f1acdc
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Mon Dec 5 14:39:23 2022 +0900

    [SPARK-41379][SS][PYTHON] Provide cloned spark session in DataFrame in user function for foreachBatch sink in PySpark

    ### What changes were proposed in this pull request?

    This PR proposes to provide the cloned Spark session in the DataFrame passed to the user function of the foreachBatch sink in PySpark.

    ### Why are the changes needed?

    It's arguably a bug: previously, the given DataFrame was associated with two different SparkSessions:

    1) the one that runs the streaming query (accessed via `df.sparkSession`)
    2) the one that microbatch execution "cloned" (accessed via `df._jdf.sparkSession()`)

    If users pick 1), it defeats the purpose of cloning the Spark session (e.g. disabling AQE). Also, which session is picked up depends on the underlying implementation of each DataFrame method, which leads to inconsistent behavior.

    The following is a problematic example:

    ```
    def user_func(batch_df, batch_id):
        batch_df.createOrReplaceTempView("updates")
        ...  # what is the right way to refer to the temp view "updates"?
    ```

    Before this PR, the only way to refer to the temp view "updates" was to use the "internal" field of DataFrame, `_jdf`: only a query run via `batch_df._jdf.sparkSession()` can see the temp view defined in the user function. We would like to make this possible without forcing end users to access an "internal" field.

    After this PR, they can (and should) use `batch_df.sparkSession` instead.

    ### Does this PR introduce _any_ user-facing change?

    Yes, this PR makes the choice of Spark session consistent.
    Users can use `df.sparkSession` to access the cloned Spark session, which is the same session that the DataFrame methods use.

    ### How was this patch tested?

    New test case, which fails on the current master branch.

    Closes #38906 from HeartSaVioR/SPARK-41379.

    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit f4ec6f2eeef7f82d478a1047231f1de1bfc429bd)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/tests/test_streaming.py | 21 +++++++++++++++++++++
 python/pyspark/sql/utils.py                |  6 +++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index 809294d34c3..076ca736af0 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -573,6 +573,27 @@ class StreamingTests(ReusedSQLTestCase):
             if q:
                 q.stop()
 
+    def test_streaming_foreachBatch_tempview(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            batch_df.createOrReplaceTempView("updates")
+            # it should use the spark session within given DataFrame, as microbatch execution will
+            # clone the session which is no longer same with the session used to start the
+            # streaming query
+            collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect()
+
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertTrue(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
     def test_streaming_foreachBatch_propagates_python_errors(self):
         from pyspark.sql.utils import StreamingQueryException
 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index e4a0299164e..3219af23c1a 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -264,9 +264,13 @@ class ForeachBatchFunction:
 
     def call(self, jdf: JavaObject, batch_id: int) -> None:
         from pyspark.sql.dataframe import DataFrame
+        from pyspark.sql.session import SparkSession
 
         try:
-            self.func(DataFrame(jdf, self.session), batch_id)
+            session_jdf = jdf.sparkSession()
+            # assuming that spark context is still the same between JVM and PySpark
+            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
+            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
         except Exception as e:
             self.error = e
             raise e
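As a rough illustration of the two-session problem described in the commit message, the following toy model is plain Python and does not use PySpark. `ToySession`, `ToyJvmFrame`, `ToyPyFrame`, `call_before_fix`, and `call_after_fix` are all hypothetical stand-ins, and the model assumes that temp views registered on a cloned session after cloning are not visible to the parent session. `call_after_fix` mirrors the idea of the `ForeachBatchFunction.call` change (take the session from `jdf.sparkSession()` rather than from the session captured when the query started); the real patch additionally wraps that JVM session in a new Python `SparkSession` that shares the existing `SparkContext`, a step this toy skips.

```python
# Hypothetical toy model of the foreachBatch session mix-up; this is NOT PySpark.


class ToySession:
    """Stand-in for a Spark session; it only tracks temp views by name."""

    def __init__(self, name):
        self.name = name
        self.temp_views = {}

    def clone(self):
        # Model assumption: the clone starts from a copy of the parent's views,
        # and views registered on the clone later are not visible to the parent.
        twin = ToySession(self.name + "-clone")
        twin.temp_views = dict(self.temp_views)
        return twin

    def select_all(self, view):
        # Toy replacement for sql("SELECT * FROM <view>").collect().
        if view not in self.temp_views:
            raise LookupError("Table or view not found: " + view)
        return list(self.temp_views[view])


class ToyJvmFrame:
    """Stand-in for the JVM-side DataFrame (the `jdf`); it knows its own session."""

    def __init__(self, rows, session):
        self.rows = rows
        self._session = session

    def sparkSession(self):
        return self._session

    def createOrReplaceTempView(self, name):
        # Registration always lands in the session the JVM frame belongs to.
        self._session.temp_views[name] = self.rows


class ToyPyFrame:
    """Stand-in for the Python DataFrame wrapper around a JVM frame."""

    def __init__(self, jdf, session):
        self._jdf = jdf
        self.sparkSession = session

    def createOrReplaceTempView(self, name):
        # Delegates to the JVM frame, so the view goes to the JVM frame's session.
        self._jdf.createOrReplaceTempView(name)


def call_before_fix(func, jdf, captured_session, batch_id):
    # Old behavior: wrap the frame with the session captured at query start.
    func(ToyPyFrame(jdf, captured_session), batch_id)


def call_after_fix(func, jdf, batch_id):
    # New behavior: wrap the frame with the session the JVM frame reports.
    func(ToyPyFrame(jdf, jdf.sparkSession()), batch_id)


# Usage: the query is started on `original`; the microbatch runs on a clone.
original = ToySession("query")
cloned = original.clone()
batch = ToyJvmFrame(["row-1", "row-2"], cloned)
collected = {}


def user_func(batch_df, batch_id):
    batch_df.createOrReplaceTempView("updates")
    collected[batch_id] = batch_df.sparkSession.select_all("updates")


before_fix_error = None
try:
    call_before_fix(user_func, batch, original, 0)
except LookupError as err:
    # The view was registered in the clone, so `original` cannot see it.
    before_fix_error = err

call_after_fix(user_func, batch, 0)
print("before fix:", before_fix_error)  # before fix: Table or view not found: updates
print("after fix:", collected[0])  # after fix: ['row-1', 'row-2']
```

Reading the session from the frame itself keeps the Python-side `sparkSession` and the session that actually executes the frame's methods pointed at the same object, which is the consistency the commit message asks for.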