This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 7405a40930f [SPARK-41379][SS][PYTHON] Provide cloned spark session in DataFrame in user function for foreachBatch sink in PySpark
7405a40930f is described below

commit 7405a40930fc58337df1d21822c913d727f1acdc
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Mon Dec 5 14:39:23 2022 +0900

    [SPARK-41379][SS][PYTHON] Provide cloned spark session in DataFrame in user function for foreachBatch sink in PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to provide the cloned spark session in the DataFrame passed to the user function for the foreachBatch sink in PySpark.
    
    ### Why are the changes needed?
    
    It's arguably a bug - previously, the given DataFrame was associated with two different SparkSessions: 1) the one which runs the streaming query (accessed via `df.sparkSession`), and 2) the one which microbatch execution "cloned" (accessed via `df._jdf.sparkSession()`). If users pick 1), it defeats the purpose of cloning the spark session, e.g. disabling AQE. Also, which session is picked up depends on the underlying implementation of each method in DataFrame, which leads to inconsistency.
    
    The following is a problematic example:
    
    ```
    def user_func(batch_df, batch_id):
      batch_df.createOrReplaceTempView("updates")
      ... # what is the right way to refer to the temp view "updates"?
    ```
    
    Before this PR, the only way to refer to the temp view "updates" was via the "internal" DataFrame field `_jdf`: only a new query run through `batch_df._jdf.sparkSession()` could see the temp view defined in the user function. We would like to make this possible without forcing end users to access an "internal" field.
    
    After this PR, they can (and should) use `batch_df.sparkSession` instead.
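
    For example, a minimal sketch of the user function after this change, extending the earlier example (the query against the temp view is illustrative, not part of this patch):

    ```
    def user_func(batch_df, batch_id):
      batch_df.createOrReplaceTempView("updates")
      # batch_df.sparkSession is the session cloned by microbatch execution,
      # so it can see the temp view registered above
      batch_df.sparkSession.sql("SELECT COUNT(*) FROM updates").show()
    ```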
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR makes it consistent which spark session is used. Users can use `df.sparkSession` to access the cloned spark session, which is the same session the methods in DataFrame will use.
    
    ### How was this patch tested?
    
    New test case which fails with current master branch.
    
    Closes #38906 from HeartSaVioR/SPARK-41379.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit f4ec6f2eeef7f82d478a1047231f1de1bfc429bd)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/tests/test_streaming.py | 21 +++++++++++++++++++++
 python/pyspark/sql/utils.py                |  6 +++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index 809294d34c3..076ca736af0 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -573,6 +573,27 @@ class StreamingTests(ReusedSQLTestCase):
             if q:
                 q.stop()
 
+    def test_streaming_foreachBatch_tempview(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            batch_df.createOrReplaceTempView("updates")
+            # It should use the spark session bound to the given DataFrame, as
+            # microbatch execution clones the session, which is no longer the
+            # same as the session used to start the streaming query.
+            collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect()
+
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
     def test_streaming_foreachBatch_propagates_python_errors(self):
         from pyspark.sql.utils import StreamingQueryException
 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index e4a0299164e..3219af23c1a 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -264,9 +264,13 @@ class ForeachBatchFunction:
 
     def call(self, jdf: JavaObject, batch_id: int) -> None:
         from pyspark.sql.dataframe import DataFrame
+        from pyspark.sql.session import SparkSession
 
         try:
-            self.func(DataFrame(jdf, self.session), batch_id)
+            session_jdf = jdf.sparkSession()
+            # assuming that spark context is still the same between JVM and PySpark
+            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
+            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
         except Exception as e:
             self.error = e
             raise e


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
