This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e7a466ebce7 [SPARK-41532][CONNECT][CLIENT] Add check for operations that involve multiple data frames
e7a466ebce7 is described below

commit e7a466ebce780664e7601f28b6dab02db4703871
Author: Hisoka <fanjiaemi...@qq.com>
AuthorDate: Mon May 8 09:29:00 2023 -0400

    [SPARK-41532][CONNECT][CLIENT] Add check for operations that involve multiple data frames
    
    ### What changes were proposed in this pull request?
    Add a check for operations that involve multiple data frames, because Spark does not support, for example, joining two data frames from different Spark Connect sessions.
    
    ### Why are the changes needed?
    Spark does not support, for example, joining two data frames from different Spark Connect sessions. To avoid exceptions, the client should fail clearly when it tries to construct such a composition.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Add new test
    
    Closes #40684 from Hisoka-X/df_from_different_session.
    
    Authored-by: Hisoka <fanjiaemi...@qq.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  9 ++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |  3 ++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 32 +++++++++++++++++++++-
 .../connect/client/util/RemoteSparkSession.scala   |  3 +-
 python/pyspark/errors/exceptions/base.py           |  6 ++++
 python/pyspark/sql/connect/dataframe.py            | 11 ++++++--
 python/pyspark/sql/connect/session.py              |  5 ++++
 .../sql/tests/connect/test_connect_basic.py        | 16 +++++++++++
 python/pyspark/testing/connectutils.py             |  2 ++
 9 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3301b483b5e..555f6c312c5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
@@ -563,6 +564,7 @@ class Dataset[T] private[sql] (
   def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot)
 
   private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
+    checkSameSparkSession(right)
     sparkSession.newDataFrame { builder =>
       val joinBuilder = builder.getJoinBuilder
       joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
@@ -1647,6 +1649,7 @@ class Dataset[T] private[sql] (
 
   private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
       f: proto.SetOperation.Builder => Unit): Dataset[T] = {
+    checkSameSparkSession(right)
     sparkSession.newDataset(encoder) { builder =>
       f(
         builder.getSetOpBuilder
@@ -1656,6 +1659,12 @@ class Dataset[T] private[sql] (
     }
   }
 
+  private def checkSameSparkSession(other: Dataset[_]): Unit = {
+    if (this.sparkSession.sessionId != other.sparkSession.sessionId) {
+      throw new SparkException("Both Datasets must belong to the same SparkSession")
+    }
+  }
+
   /**
   * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
    *
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index a8bfac5d71f..bc6cf32379f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -71,6 +71,9 @@ class SparkSession private[sql] (
 
   private[this] val allocator = new RootAllocator()
 
+  // A unique session ID for this session, provided by the client.
+  private[sql] def sessionId: String = client.sessionId
+
   lazy val version: String = {
     client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 8a01f828ef0..33fc5d5a4a2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -32,11 +32,13 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.SPARK_VERSION
+import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
+import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -193,6 +195,34 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     }
   }
 
+  test("different spark session join/union") {
+    val df = spark.range(10).limit(3)
+
+    val spark2 = SparkSession
+      .builder()
+      .client(
+        SparkConnectClient
+          .builder()
+          .port(port)
+          .build())
+      .build()
+
+    val df2 = spark2.range(10).limit(3)
+
+    assertThrows[SparkException] {
+      df.union(df2).collect()
+    }
+
+    assertThrows[SparkException] {
+      df.unionByName(df2).collect()
+    }
+
+    assertThrows[SparkException] {
+      df.join(df2).collect()
+    }
+
+  }
+
   test("write without table or path") {
     // Should receive no error to write noop
     spark.range(10).write.format("noop").mode("append").save()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 235605e3121..1476b16da5d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -48,7 +48,8 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon
 object SparkConnectServerUtils {
 
   // Server port
-  private[connect] val port = ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
+  private[spark] val port: Int =
+    ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
 
   @volatile private var stopped = false
 
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index 1b9a6b0229e..fd1c07c4df6 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -102,6 +102,12 @@ class AnalysisException(PySparkException):
     """
 
 
+class SessionNotSameException(PySparkException):
+    """
+    Raised when an operation is performed on DataFrames from different SparkSessions.
+    """
+
+
 class TempTableAlreadyExistsException(AnalysisException):
     """
     Failed to create the temp view since it already exists.
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 50eadf46200..87f061f139e 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from pyspark.errors.exceptions.base import SessionNotSameException
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -254,7 +255,7 @@ class DataFrame:
             raise Exception("Cannot cartesian join when self._plan is empty.")
         if other._plan is None:
             raise Exception("Cannot cartesian join when other._plan is empty.")
-
+        self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.Join(left=self._plan, right=other._plan, on=None, how="cross"),
             session=self._session,
@@ -262,6 +263,10 @@ class DataFrame:
 
     crossJoin.__doc__ = PySparkDataFrame.crossJoin.__doc__
 
+    def checkSameSparkSession(self, other: "DataFrame") -> None:
+        if self._session.session_id != other._session.session_id:
+            raise SessionNotSameException("Both Datasets must belong to the same SparkSession")
+
     def coalesce(self, numPartitions: int) -> "DataFrame":
         if not numPartitions > 0:
             raise PySparkValueError(
@@ -560,7 +565,7 @@ class DataFrame:
             raise Exception("Cannot join when other._plan is empty.")
         if how is not None and isinstance(how, str):
             how = how.lower().replace("_", "")
-
+        self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.Join(left=self._plan, right=other._plan, on=on, how=how),
             session=self._session,
@@ -1005,6 +1010,7 @@ class DataFrame:
                 error_class="MISSING_VALID_PLAN",
                 message_parameters={"operator": "Union"},
             )
+        self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.SetOperation(self._plan, other._plan, "union", is_all=True), session=self._session
         )
@@ -1017,6 +1023,7 @@ class DataFrame:
                 error_class="MISSING_VALID_PLAN",
                 message_parameters={"operator": "UnionByName"},
             )
+        self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.SetOperation(
                 self._plan,
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 3bd842f7847..4f8fa419119 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -209,6 +209,7 @@ class SparkSession:
             takes precedence.
         """
         self._client = SparkConnectClient(connection=connection, userId=userId)
+        self._session_id = self._client._session_id
 
     def table(self, tableName: str) -> DataFrame:
         return self.read.table(tableName)
@@ -706,6 +707,10 @@ class SparkSession:
         else:
             raise RuntimeError("There should not be an existing Spark Session 
or Spark Context.")
 
+    @property
+    def session_id(self) -> str:
+        return self._session_id
+
 
 SparkSession.__doc__ = PySparkSession.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d00a8a797ae..45dbe182f12 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -29,6 +29,7 @@ from pyspark.errors import (
     PySparkException,
     PySparkValueError,
 )
+from pyspark.errors.exceptions.base import SessionNotSameException
 from pyspark.sql import SparkSession as PySparkSession, Row
 from pyspark.sql.types import (
     StructType,
@@ -1796,6 +1797,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "ShuffledHashJoin" in cdf1.join(cdf2.hint("SHUFFLE_HASH"), 
"name")._explain_string()
         )
 
+    def test_different_spark_session_join_or_union(self):
+        df = self.connect.range(10).limit(3)
+
+        spark2 = RemoteSparkSession(connection="sc://localhost")
+        df2 = spark2.range(10).limit(3)
+
+        with self.assertRaises(SessionNotSameException):
+            df.union(df2).collect()
+
+        with self.assertRaises(SessionNotSameException):
+            df.unionByName(df2).collect()
+
+        with self.assertRaises(SessionNotSameException):
+            df.join(df2).collect()
+
     def test_extended_hint_types(self):
         cdf = self.connect.range(100).toDF("id")
 
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index c1ca57aa3cc..68e08d5244f 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -20,6 +20,7 @@ import typing
 import os
 import functools
 import unittest
+import uuid
 
 from pyspark import Row, SparkConf
 from pyspark.testing.utils import PySparkErrorTestUtils
@@ -73,6 +74,7 @@ if should_test_connect:
 class MockRemoteSession:
     def __init__(self):
         self.hooks = {}
+        self.session_id = str(uuid.uuid4())
 
     def set_hook(self, name, hook):
         self.hooks[name] = hook


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
