This is an automated email from the ASF dual-hosted git repository. yangjie01 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new fa2e53f6105 [SPARK-44784][CONNECT] Make SBT testing hermetic fa2e53f6105 is described below commit fa2e53f6105c60effdf210cab4c8d77f13fee6b6 Author: Herman van Hovell <her...@databricks.com> AuthorDate: Sun Aug 27 11:28:58 2023 +0800 [SPARK-44784][CONNECT] Make SBT testing hermetic ### What changes were proposed in this pull request? This PR makes a bunch of changes to connect testing for the scala client: - We do not start the connect server with the `SPARK_DIST_CLASSPATH` environment variable. This is set by the build system, but its value for SBT and Maven is different. For SBT it also contained the client code. - We use dependency upload to add the dependencies needed for the tests. Currently this entails: the compiled test classes (class files), scalatest jars, and scalactic jars. - The use of classfile sync unearthed an issue with stubbing and the `ExecutorClassLoader`. If they load classes in the same namespace then stubbing will generate stubs for classes that can be loaded by the `ExecutorClassLoader`. Since this is mostly a testing issue I decided to move the test code to a different namespace. We should definitely fix this later on. - A bunch of tiny fixes. ### Why are the changes needed? SBT testing for connect leaked client side code into the server. This is a problem because tests pass and we sign off on features that do not work well in a normal environment. Stubbing was an example of this. Maven did not have this problem and was therefore more correct. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? It is mostly tests. ### Was this patch authored or co-authored using generative AI tooling? No. I write my own code thank you... Closes #42591 from hvanhovell/investigate-stubbing. 
Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: yangjie01 <yangji...@baidu.com> (cherry picked from commit 9326615592eac14c7cab3dd126b3c21222b7778f) Signed-off-by: yangjie01 <yangji...@baidu.com> --- connector/connect/bin/spark-connect-build | 6 +- connector/connect/bin/spark-connect-scala-client | 2 +- .../bin/spark-connect-scala-client-classpath | 3 +- connector/connect/client/jvm/pom.xml | 5 - .../apache/spark/sql/connect/client/ToStub.scala} | 11 +- .../org/apache/spark/sql/JavaEncoderSuite.java | 12 +- .../scala/org/apache/spark/sql/CatalogSuite.scala | 2 +- .../spark/sql/ClientDataFrameStatSuite.scala | 2 +- .../org/apache/spark/sql/ClientDatasetSuite.scala | 2 +- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 4 +- .../org/apache/spark/sql/ColumnTestSuite.scala | 2 +- .../spark/sql/DataFrameNaFunctionSuite.scala | 2 +- .../org/apache/spark/sql/FunctionTestSuite.scala | 2 +- .../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 2 +- .../apache/spark/sql/PlanGenerationTestSuite.scala | 3 +- .../apache/spark/sql/SQLImplicitsTestSuite.scala | 2 +- .../apache/spark/sql/SparkSessionE2ESuite.scala | 2 +- .../org/apache/spark/sql/SparkSessionSuite.scala | 2 +- .../org/apache/spark/sql/StubbingTestSuite.scala} | 23 ++- .../client => }/UDFClassLoadingE2ESuite.scala | 5 +- .../sql/UserDefinedFunctionE2ETestSuite.scala | 16 +- .../spark/sql/UserDefinedFunctionSuite.scala | 2 +- .../spark/sql/application/ReplE2ESuite.scala | 2 +- .../spark/sql/connect/client/ArtifactSuite.scala | 2 +- .../CheckConnectJvmClientCompatibility.scala | 2 +- .../sql/connect/client/ClassFinderSuite.scala | 2 +- .../SparkConnectClientBuilderParseTestSuite.scala | 2 +- .../connect/client/SparkConnectClientSuite.scala | 2 +- .../connect/client/arrow/ArrowEncoderSuite.scala | 2 +- .../connect/client/util/RemoteSparkSession.scala | 228 --------------------- .../sql/streaming/ClientStreamingQuerySuite.scala | 4 +- .../FlatMapGroupsWithStateStreamingSuite.scala | 4 +- 
.../streaming/StreamingQueryProgressSuite.scala | 2 +- .../client/util => test}/ConnectFunSuite.scala | 2 +- .../util => test}/IntegrationTestUtils.scala | 22 +- .../{connect/client/util => test}/QueryTest.scala | 2 +- .../apache/spark/sql/test/RemoteSparkSession.scala | 224 ++++++++++++++++++++ .../apache/spark/sql/{ => test}/SQLHelper.scala | 3 +- .../spark/sql/connect/client/ArtifactManager.scala | 20 +- .../sql/connect/client/GrpcRetryHandler.scala | 4 +- .../sql/connect/client/SparkConnectClient.scala | 2 +- .../spark/sql/connect/common/ProtoDataTypes.scala | 2 +- .../sql/connect/common/config/ConnectCommon.scala | 2 +- .../artifact/SparkConnectArtifactManager.scala | 40 +++- .../org/apache/spark/util/StubClassLoader.scala | 5 +- project/SparkBuild.scala | 1 - 46 files changed, 369 insertions(+), 324 deletions(-) diff --git a/connector/connect/bin/spark-connect-build b/connector/connect/bin/spark-connect-build index ca8d4cf6e90..63c17d8f7aa 100755 --- a/connector/connect/bin/spark-connect-build +++ b/connector/connect/bin/spark-connect-build @@ -29,7 +29,5 @@ SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1 SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VER} | head -n1 | awk -F '[<>]' '{print $3}'` SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" -# Build the jars needed for spark submit and spark connect -build/sbt "${SCALA_ARG}" -Phive -Pconnect package || exit 1 -# Build the jars needed for spark connect JVM client -build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1 +# Build the jars needed for spark submit and spark connect JVM client +build/sbt "${SCALA_ARG}" -Phive -Pconnect package "connect-client-jvm/package" || exit 1 diff --git a/connector/connect/bin/spark-connect-scala-client b/connector/connect/bin/spark-connect-scala-client index ef394df4e0f..ffa77f70842 100755 --- a/connector/connect/bin/spark-connect-scala-client +++ 
b/connector/connect/bin/spark-connect-scala-client @@ -45,7 +45,7 @@ SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" SCBUILD="${SCBUILD:-1}" if [ "$SCBUILD" -eq "1" ]; then # Build the jars needed for spark connect JVM client - build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1 + build/sbt "${SCALA_ARG}" "connect-client-jvm/package" || exit 1 fi if [ -z "$SCCLASSPATH" ]; then diff --git a/connector/connect/bin/spark-connect-scala-client-classpath b/connector/connect/bin/spark-connect-scala-client-classpath index 99a22f3d5ff..9d33e90bf09 100755 --- a/connector/connect/bin/spark-connect-scala-client-classpath +++ b/connector/connect/bin/spark-connect-scala-client-classpath @@ -30,6 +30,5 @@ SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VE SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" CONNECT_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export connect-client-jvm/fullClasspath" | grep jar | tail -n1)" -SQL_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export sql/fullClasspath" | grep jar | tail -n1)" -echo "$CONNECT_CLASSPATH:$CLASSPATH" +echo "$CONNECT_CLASSPATH" diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 4ea3ef1e48b..a7e5c5c2bab 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -113,11 +113,6 @@ <artifactId>scalacheck_${scala.binary.version}</artifactId> <scope>test</scope> </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <scope>test</scope> - </dependency> <!-- Use mima to perform the compatibility check --> <dependency> <groupId>com.typesafe</groupId> diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala similarity index 78% copy from 
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala copy to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala index 3f594d79b62..9a5fda1189d 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala @@ -14,9 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.connect.common.config +package org.apache.spark.sql.connect.client -private[connect] object ConnectCommon { - val CONNECT_GRPC_BINDING_PORT: Int = 15002 - val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024; -} +/** + * Class used to test stubbing. This needs to be in the main source tree, because this is not + * synced with the connect server during tests. + */ +case class ToStub(value: Long) diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index 6e5fb72d496..d5fdede774f 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -27,9 +27,8 @@ import static org.junit.Assert.*; import static org.apache.spark.sql.Encoders.*; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.RowFactory.create; -import org.apache.spark.sql.connect.client.SparkConnectClient; -import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils; import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.test.SparkConnectServerUtils; import org.apache.spark.sql.types.StructType; /** @@ -40,14 +39,7 @@ public class JavaEncoderSuite implements Serializable { 
@BeforeClass public static void setup() { - SparkConnectServerUtils.start(); - spark = SparkSession - .builder() - .client(SparkConnectClient - .builder() - .port(SparkConnectServerUtils.port()) - .build()) - .create(); + spark = SparkConnectServerUtils.createSparkSession(); } @AfterClass diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index fa97498f7e7..cefa63ecd35 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -22,7 +22,7 @@ import java.io.{File, FilenameFilter} import org.apache.commons.io.FileUtils import org.apache.spark.SparkException -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.test.{RemoteSparkSession, SQLHelper} import org.apache.spark.sql.types.{DoubleType, LongType, StructType} import org.apache.spark.storage.StorageLevel diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 2f4e1aa9bd0..069d8ec502f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -22,7 +22,7 @@ import java.util.Random import org.scalatest.matchers.must.Matchers._ import org.apache.spark.{SparkException, SparkIllegalArgumentException} -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.test.RemoteSparkSession class ClientDataFrameStatSuite extends RemoteSparkSession { private def toLetter(i: Int): String = (i + 97).toChar.toString diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index ae20f771d6c..a521c6745a9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.ConnectFunSuite // Add sample tests. // - sample fraction: simple.sample(0.1) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 86973b82e72..fd443b73925 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -36,10 +36,10 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} -import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} -import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.test.SparkConnectServerUtils.port import org.apache.spark.sql.types._ class 
ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index 0d361fe1007..a88d6ec116a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream import scala.collection.JavaConverters._ import org.apache.spark.sql.{functions => fn} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index ac64d4411a8..393fa19fa70 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 4a8e108357f..78cc26d627c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -21,9 +21,9 @@ import java.util.Collections import scala.collection.JavaConverters._ import org.apache.spark.sql.avro.{functions => avroFn} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.functions._ import org.apache.spark.sql.protobuf.{functions => pbFn} +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{DataType, StructType} /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index 380ca2fb72b..3e979be73a7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -20,9 +20,9 @@ import java.sql.Timestamp import java.util.Arrays import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types._ case class ClickEvent(id: String, timestamp: Timestamp) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 11d0696b6e1..4916ff1f597 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -37,11 +37,10 @@ import org.apache.spark.sql.avro.{functions => avroFn} import 
org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.client.util.ConnectFunSuite -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.{functions => pbFn} +import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.SparkFileUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 6db38bfb1c3..680380c91a0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite /** * Test suite for SQL implicits. 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 490bdf9cd86..c76dc724828 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -26,7 +26,7 @@ import scala.util.{Failure, Success} import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkException -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.test.RemoteSparkSession import org.apache.spark.util.SparkThreadUtils.awaitResult /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 4aa8b4360ee..90fe8f57d07 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite /** * Tests for non-dataframe related SparkSession operations. 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala similarity index 62% copy from connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala copy to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala index 3f594d79b62..b9c5888e5cb 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala @@ -14,9 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.connect.common.config +package org.apache.spark.sql -private[connect] object ConnectCommon { - val CONNECT_GRPC_BINDING_PORT: Int = 15002 - val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024; +import org.apache.spark.sql.connect.client.ToStub +import org.apache.spark.sql.test.RemoteSparkSession + +class StubbingTestSuite extends RemoteSparkSession { + private def eval[T](f: => T): T = f + + test("capture of to-be stubbed class") { + val session = spark + import session.implicits._ + val result = spark + .range(0, 10, 1, 1) + .map(n => n + 1) + .as[ToStub] + .head() + eval { + assert(result.value == 1) + } + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala index 8fdb7efbcba..a76e046db2e 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala @@ -14,17 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.connect.client +package org.apache.spark.sql import java.io.File import java.nio.file.{Files, Paths} import scala.util.Properties -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.client.util.RemoteSparkSession import org.apache.spark.sql.connect.common.ProtoDataTypes import org.apache.spark.sql.expressions.ScalarUserDefinedFunction +import org.apache.spark.sql.test.RemoteSparkSession class UDFClassLoadingE2ESuite extends RemoteSparkSession { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index d00659ac2d8..0af8c78a1da 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -26,8 +26,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} -import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions.{col, struct, udf} +import org.apache.spark.sql.test.QueryTest import org.apache.spark.sql.types.IntegerType /** @@ -215,33 +215,31 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { } test("Dataset foreachPartition") { - val sum = new AtomicLong() val func: Iterator[JLong] => Unit = f => { + val sum = new 
AtomicLong() f.foreach(v => sum.addAndGet(v)) - // The value should be 45 - assert(sum.get() == -1) + throw new Exception("Success, processed records: " + sum.get()) } val exception = intercept[Exception] { spark.range(10).repartition(1).foreachPartition(func) } - assert(exception.getMessage.contains("45 did not equal -1")) + assert(exception.getMessage.contains("Success, processed records: 45")) } test("Dataset foreachPartition - java") { val sum = new AtomicLong() val exception = intercept[Exception] { spark - .range(10) + .range(11) .repartition(1) .foreachPartition(new ForeachPartitionFunction[JLong] { override def call(t: JIterator[JLong]): Unit = { t.asScala.foreach(v => sum.addAndGet(v)) - // The value should be 45 - assert(sum.get() == -1) + throw new Exception("Success, processed records: " + sum.get()) } }) } - assert(exception.getMessage.contains("45 did not equal -1")) + assert(exception.getMessage.contains("Success, processed records: 55")) } test("Dataset foreach: change not visible to client") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala index 76608559866..923aa5af75b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -20,9 +20,9 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.connect.common.UdfPacket import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils class UserDefinedFunctionSuite extends ConnectFunSuite { diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 5a909ab8b41..4106d298dbe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -25,7 +25,7 @@ import scala.util.Properties import org.apache.commons.io.output.ByteArrayOutputStream import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} +import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession} class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index 7901008bc12..770143f2e9b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto.AddArtifactsRequest import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 8f226eb2f7e..1f599f2346e 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.lib.MiMaLib -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._ +import org.apache.spark.sql.test.IntegrationTestUtils._ /** * A tool for checking the binary compatibility of the connect client API against the spark SQL diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala index 625d4cf43e1..ca23436675f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -20,7 +20,7 @@ import java.nio.file.Paths import org.apache.commons.io.FileUtils -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkFileUtils class ClassFinderSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala index 1dc1fd567ec..e1d4a18d0ff 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala @@ -18,7 +18,7 @@ package 
org.apache.spark.sql.connect.client import java.util.UUID -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite /** * Test suite for [[SparkConnectClient.Builder]] parsing and configuration. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 6348e0e49ca..80e245ec78b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -31,8 +31,8 @@ import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.test.ConnectFunSuite class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 2a499cc548f..b6ad27d3e52 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ import 
org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala deleted file mode 100644 index 33540bf4985..00000000000 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.connect.client.util - -import java.io.{BufferedOutputStream, File} -import java.util.concurrent.TimeUnit - -import scala.io.Source - -import org.scalatest.BeforeAndAfterAll -import sys.process._ - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._ -import org.apache.spark.sql.connect.common.config.ConnectCommon - -/** - * An util class to start a local spark connect server in a different process for local E2E tests. - * Pre-running the tests, the spark connect artifact needs to be built using e.g. `build/sbt - * package`. It is designed to start the server once but shared by all tests. It is equivalent to - * use the following command to start the connect server via command line: - * - * {{{ - * bin/spark-shell \ - * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \ - * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin - * }}} - * - * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed - * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` to - * print the server process output in the console to debug server start stop problems. 
- */ -object SparkConnectServerUtils { - - // Server port - val port: Int = - ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) - - @volatile private var stopped = false - - private var consoleOut: BufferedOutputStream = _ - private val serverStopCommand = "q" - - private lazy val sparkConnect: Process = { - debug("Starting the Spark Connect Server...") - val connectJar = findJar( - "connector/connect/server", - "spark-connect-assembly", - "spark-connect").getCanonicalPath - - val builder = Process( - Seq( - "bin/spark-submit", - "--driver-class-path", - connectJar, - "--conf", - s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ debugConfigs ++ Seq( - "--class", - "org.apache.spark.sql.connect.SimpleSparkConnectService", - connectJar), - new File(sparkHome)) - - val io = new ProcessIO( - in => consoleOut = new BufferedOutputStream(in), - out => Source.fromInputStream(out).getLines.foreach(debug), - err => Source.fromInputStream(err).getLines.foreach(debug)) - val process = builder.run(io) - - // Adding JVM shutdown hook - sys.addShutdownHook(stop()) - process - } - - /** - * As one shared spark will be started for all E2E tests, for tests that needs some special - * configs, we add them here - */ - private def testConfigs: Seq[String] = { - // To find InMemoryTableCatalog for V2 writer tests - val catalystTestJar = - tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true) - .map(clientTestJar => Seq(clientTestJar.getCanonicalPath)) - .getOrElse(Seq.empty) - - // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests. - val connectClientTestJar = tryFindJar( - "connector/connect/client/jvm", - // SBT passes the client & test jars to the server process automatically. - // So we skip building or finding this jar for SBT. 
- "sbt-tests-do-not-need-this-jar", - "spark-connect-client-jvm", - test = true) - .map(clientTestJar => Seq(clientTestJar.getCanonicalPath)) - .getOrElse(Seq.empty) - - val allJars = catalystTestJar ++ connectClientTestJar - val jarsConfigs = Seq("--jars", allJars.mkString(",")) - - // Use InMemoryTableCatalog for V2 writer tests - val writerV2Configs = Seq( - "--conf", - "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog") - - // Run tests using hive - val hiveTestConfigs = { - val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { - "hive" - } else { - // scalastyle:off println - println( - "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " + - "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" + - "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" + - "2. Test with sbt: run test with `-Phive` profile") - // scalastyle:on println - // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive - // module to avoid unexpected loading of `DataSourceRegister` in hive module during - // testing without `-Phive` profile. - IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded() - "in-memory" - } - Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation") - } - - // Make the server terminate reattachable streams every 1 second and 123 bytes, - // to make the tests exercise reattach. 
- val reattachExecuteConfigs = Seq( - "--conf", - "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", - "--conf", - "spark.connect.execute.reattachable.senderMaxStreamSize=123") - - jarsConfigs ++ writerV2Configs ++ hiveTestConfigs ++ reattachExecuteConfigs - } - - def start(): Unit = { - assert(!stopped) - sparkConnect - } - - def stop(): Int = { - stopped = true - debug("Stopping the Spark Connect Server...") - try { - consoleOut.write(serverStopCommand.getBytes) - consoleOut.flush() - consoleOut.close() - } catch { - case e: Throwable => - debug(e) - sparkConnect.destroy() - } - - val code = sparkConnect.exitValue() - debug(s"Spark Connect Server is stopped with exit code: $code") - code - } -} - -trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { - import SparkConnectServerUtils._ - var spark: SparkSession = _ - protected lazy val serverPort: Int = port - - override def beforeAll(): Unit = { - super.beforeAll() - SparkConnectServerUtils.start() - spark = SparkSession - .builder() - .client(SparkConnectClient.builder().port(serverPort).build()) - .create() - - // Retry and wait for the server to start - val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min - var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff - var success = false - val error = new RuntimeException(s"Failed to start the test server on port $serverPort.") - - while (!success && System.nanoTime() < stop) { - try { - // Run a simple query to verify the server is really up and ready - val result = spark - .sql("select val from (values ('Hello'), ('World')) as t(val)") - .collect() - assert(result.length == 2) - success = true - debug("Spark Connect Server is up.") - } catch { - // ignored the error - case e: Throwable => - error.addSuppressed(e) - Thread.sleep(sleepInternalMs) - sleepInternalMs *= 2 - } - } - - // Throw error if failed - if (!success) { - debug(error) - throw error - } - } - - override def afterAll(): Unit = { 
- try { - if (spark != null) spark.stop() - } catch { - case e: Throwable => debug(e) - } - spark = null - super.afterAll() - } -} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 944a999a860..dc4d441ec30 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -29,11 +29,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper} -import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession} import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.window import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent} +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.util.SparkFileUtils class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala index cdb6b9a2e9c..2fab6e8e3c8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.concurrent.Eventually.eventually 
import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.{SparkSession, SQLHelper} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} case class ClickEvent(id: String, timestamp: Timestamp) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala index a6a44c1bd71..1a72252a417 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.StructType class StreamingQueryProgressSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala index 0a1e794c8e7..8d69d91a34f 100755 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.nio.file.Path diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala similarity index 87% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala index 4d88565308f..61d08912aec 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.io.File import java.nio.file.{Files, Paths} @@ -30,8 +30,12 @@ object IntegrationTestUtils { // System properties used for testing and debugging private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client" + private val DEBUG_SC_JVM_CLIENT_ENV = "SPARK_DEBUG_SC_JVM_CLIENT" // Enable this flag to print all server logs to the console - private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean + private[sql] val isDebug = { + System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean || + Option(System.getenv(DEBUG_SC_JVM_CLIENT_ENV)).exists(_.toBoolean) + } private[sql] lazy val scalaVersion = { versionNumberString.split('.') match { @@ -49,8 +53,14 @@ object IntegrationTestUtils { sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) } - private[connect] def debugConfigs: Seq[String] = { - val log4j2 = s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties" + private[sql] lazy val connectClientHomeDir = s"$sparkHome/connector/connect/client/jvm" + + private[sql] lazy val connectClientTestClassDir = { + s"$connectClientHomeDir/target/$scalaDir/test-classes" + } + + private[sql] def debugConfigs: Seq[String] = { + val log4j2 = s"$connectClientHomeDir/src/test/resources/log4j2.properties" if (isDebug) { Seq( // Enable to see the server plan change log @@ -70,9 +80,9 @@ object IntegrationTestUtils { // Log server start stop debug info into console // scalastyle:off println - private[connect] def debug(msg: String): Unit = if (isDebug) println(msg) + private[sql] def debug(msg: String): Unit = if (isDebug) println(msg) // scalastyle:on println - private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace() + private[sql] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace() private[sql] lazy val isSparkHiveJarAvailable: Boolean = { val filePath = 
s"$sparkHome/assembly/target/$scalaDir/jars/" + diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala similarity index 99% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala index a0d3d4368dd..adbd8286090 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.util.TimeZone diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala new file mode 100644 index 00000000000..8a8f739a7c5 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.test + +import java.io.{File, IOException, OutputStream} +import java.lang.ProcessBuilder +import java.lang.ProcessBuilder.Redirect +import java.nio.file.Paths +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.FiniteDuration + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkBuildInfo +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.test.IntegrationTestUtils._ + +/** + * An util class to start a local spark connect server in a different process for local E2E tests. + * Pre-running the tests, the spark connect artifact needs to be built using e.g. `build/sbt + * package`. It is designed to start the server once but shared by all tests. It is equivalent to + * use the following command to start the connect server via command line: + * + * {{{ + * bin/spark-shell \ + * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \ + * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin + * }}} + * + * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed + * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` or + * environment variable `SPARK_DEBUG_SC_JVM_CLIENT=true` to print the server process output in the + * console to debug server start stop problems. 
+ */ +object SparkConnectServerUtils { + + // Server port + val port: Int = + ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) + + @volatile private var stopped = false + + private var consoleOut: OutputStream = _ + private val serverStopCommand = "q" + + private lazy val sparkConnect: java.lang.Process = { + debug("Starting the Spark Connect Server...") + val connectJar = findJar( + "connector/connect/server", + "spark-connect-assembly", + "spark-connect").getCanonicalPath + + val command = Seq.newBuilder[String] + command += "bin/spark-submit" + command += "--driver-class-path" += connectJar + command += "--class" += "org.apache.spark.sql.connect.SimpleSparkConnectService" + command += "--conf" += s"spark.connect.grpc.binding.port=$port" + command ++= testConfigs + command ++= debugConfigs + command += connectJar + val builder = new ProcessBuilder(command.result(): _*) + builder.directory(new File(sparkHome)) + val environment = builder.environment() + environment.remove("SPARK_DIST_CLASSPATH") + if (isDebug) { + builder.redirectError(Redirect.INHERIT) + builder.redirectOutput(Redirect.INHERIT) + } + + val process = builder.start() + consoleOut = process.getOutputStream + + // Adding JVM shutdown hook + sys.addShutdownHook(stop()) + process + } + + /** + * As one shared spark will be started for all E2E tests, for tests that needs some special + * configs, we add them here + */ + private def testConfigs: Seq[String] = { + // To find InMemoryTableCatalog for V2 writer tests + val catalystTestJar = + findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath + + val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { + "hive" + } else { + // scalastyle:off println + println( + "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " + + "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" + + "1. 
Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" + + "2. Test with sbt: run test with `-Phive` profile") + // scalastyle:on println + // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive + // module to avoid unexpected loading of `DataSourceRegister` in hive module during + // testing without `-Phive` profile. + IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded() + "in-memory" + } + val confs = Seq( + // Use InMemoryTableCatalog for V2 writer tests + "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog", + // Try to use the hive catalog, fallback to in-memory if it is not there. + "spark.sql.catalogImplementation=" + catalogImplementation, + // Make the server terminate reattachable streams every 1 second and 123 bytes, + // to make the tests exercise reattach. + "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", + "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Disable UI + "spark.ui.enabled=false") + Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) + } + + def start(): Unit = { + assert(!stopped) + sparkConnect + } + + def stop(): Int = { + stopped = true + debug("Stopping the Spark Connect Server...") + try { + consoleOut.write(serverStopCommand.getBytes) + consoleOut.flush() + consoleOut.close() + if (!sparkConnect.waitFor(2, TimeUnit.SECONDS)) { + sparkConnect.destroyForcibly() + } + val code = sparkConnect.exitValue() + debug(s"Spark Connect Server is stopped with exit code: $code") + code + } catch { + case e: IOException if e.getMessage.contains("Stream closed") => + -1 + case e: Throwable => + debug(e) + sparkConnect.destroyForcibly() + throw e + } + } + + def syncTestDependencies(spark: SparkSession): Unit = { + // Both SBT & Maven pass the test-classes as a directory instead of a jar. 
+ val testClassesPath = Paths.get(IntegrationTestUtils.connectClientTestClassDir) + spark.client.artifactManager.addClassDir(testClassesPath) + + // We need scalatest & scalactic on the session's classpath to make the tests work. + val jars = System + .getProperty("java.class.path") + .split(File.pathSeparatorChar) + .filter { e: String => + val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1) + fileName.endsWith(".jar") && + (fileName.startsWith("scalatest") || fileName.startsWith("scalactic")) + } + .map(e => Paths.get(e).toUri) + spark.client.artifactManager.addArtifacts(jars) + } + + def createSparkSession(): SparkSession = { + SparkConnectServerUtils.start() + + val spark = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .retryPolicy(RetryPolicy(maxRetries = 7, maxBackoff = FiniteDuration(10, "s"))) + .build()) + .create() + + // Execute an RPC which will get retried until the server is up. + assert(spark.version == SparkBuildInfo.spark_version) + + // Auto-sync dependencies. 
+ SparkConnectServerUtils.syncTestDependencies(spark) + + spark + } +} + +trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { + import SparkConnectServerUtils._ + var spark: SparkSession = _ + protected lazy val serverPort: Int = port + + override def beforeAll(): Unit = { + super.beforeAll() + spark = createSparkSession() + } + + override def afterAll(): Unit = { + try { + if (spark != null) spark.stop() + } catch { + case e: Throwable => debug(e) + } + spark = null + super.afterAll() + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala index f357270e20f..12212492e37 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala @@ -14,13 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.test import java.io.File import java.util.UUID import org.scalatest.Assertions.fail +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils} trait SQLHelper { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 136a31fca3c..b1a7746a84a 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -145,10 +145,28 @@ class ArtifactManager( addArtifacts(classFinders.asScala.flatMap(_.findClasses())) } + private[sql] def addClassDir(base: Path): Unit = { + if (!Files.isDirectory(base)) { + return + } + val builder = Seq.newBuilder[Artifact] + val stream = Files.walk(base) + try { + stream.forEach { path => + if (Files.isRegularFile(path) && path.toString.endsWith(".class")) { + builder += Artifact.newClassArtifact(base.relativize(path), new LocalFile(path)) + } + } + } finally { + stream.close() + } + addArtifacts(builder.result()) + } + /** * Add a number of artifacts to the session. 
*/ - private def addArtifacts(artifacts: Iterable[Artifact]): Unit = { + private[client] def addArtifacts(artifacts: Iterable[Artifact]): Unit = { if (artifacts.isEmpty) { return } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index 8b6f070b8f5..a6841e7f118 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -26,7 +26,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.internal.Logging -private[client] class GrpcRetryHandler( +private[sql] class GrpcRetryHandler( private val retryPolicy: GrpcRetryHandler.RetryPolicy, private val sleep: Long => Unit = Thread.sleep) { @@ -146,7 +146,7 @@ private[client] class GrpcRetryHandler( } } -private[client] object GrpcRetryHandler extends Logging { +private[sql] object GrpcRetryHandler extends Logging { /** * Retries the given function with exponential backoff according to the client's retryPolicy. diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index c41f6dfaae1..a0853cc0621 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -58,7 +58,7 @@ private[sql] class SparkConnectClient( // a new client will create a new session ID. 
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString) - private[client] val artifactManager: ArtifactManager = { + private[sql] val artifactManager: ArtifactManager = { new ArtifactManager(configuration, sessionId, bstub, stub) } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala index 19890558ab2..e85a2a40da2 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.common import org.apache.spark.connect.proto -private[connect] object ProtoDataTypes { +private[sql] object ProtoDataTypes { val NullType: proto.DataType = proto.DataType .newBuilder() diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala index 3f594d79b62..dca65cf905f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.connect.common.config -private[connect] object ConnectCommon { +private[sql] object ConnectCommon { val CONNECT_GRPC_BINDING_PORT: Int = 15002 val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024; } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index c1dd7820c55..a2df11eeb58 100644 --- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_PREFIXES +import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES, EXECUTOR_USER_CLASS_PATH_FIRST} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL @@ -162,15 +162,37 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging */ def classloader: ClassLoader = { val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL - val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty) { - val stubClassLoader = - StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)) - new ChildFirstURLClassLoader( - urls.toArray, - stubClassLoader, - Utils.getContextOrSparkClassLoader) + val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES) + val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST) + val loader = if (prefixes.nonEmpty) { + // Two things you need to know about classloader for all of this to make sense: + // 1. A classloader needs to be able to fully define a class. + // 2. Classes are loaded lazily. Only when a class is used the classes it references are + // loaded. + // This makes stubbing a bit more complicated than you'd expect.
We cannot put the stubbing + // classloader as a fallback at the end of the loading process, because then classes that + // have been found in one of the parent classloaders and that contain a reference to a + // missing, to-be-stubbed missing class will still fail with classloading errors later on. + // The way we currently fix this is by making the stubbing class loader the last classloader + // it delegates to. + if (userClasspathFirst) { + // USER -> SYSTEM -> STUB + new ChildFirstURLClassLoader( + urls.toArray, + StubClassLoader(Utils.getContextOrSparkClassLoader, prefixes)) + } else { + // SYSTEM -> USER -> STUB + new ChildFirstURLClassLoader( + urls.toArray, + StubClassLoader(null, prefixes), + Utils.getContextOrSparkClassLoader) + } } else { - new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + if (userClasspathFirst) { + new ChildFirstURLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + } else { + new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + } } logDebug(s"Using class loader: $loader, containing urls: $urls") diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala index e27376e2b83..8d903c2a3e4 100644 --- a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala @@ -18,6 +18,8 @@ package org.apache.spark.util import org.apache.xbean.asm9.{ClassWriter, Opcodes} +import org.apache.spark.internal.Logging + /** * [[ClassLoader]] that replaces missing classes with stubs, if the cannot be found. It will only * do this for classes that are marked for stubbing. @@ -27,11 +29,12 @@ import org.apache.xbean.asm9.{ClassWriter, Opcodes} * the class and therefor is safe to replace by a stub. 
*/ class StubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) - extends ClassLoader(parent) { + extends ClassLoader(parent) with Logging { override def findClass(name: String): Class[_] = { if (!shouldStub(name)) { throw new ClassNotFoundException(name) } + logDebug(s"Generating stub for $name") val bytes = StubClassLoader.generateStub(name) defineClass(name, bytes, 0, bytes.length) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4b6f617e68f..563d5357754 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -858,7 +858,6 @@ object SparkConnectClient { "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" ) }, - dependencyOverrides ++= { val guavaVersion = SbtPomKeys.effectivePom.value.getProperties.get( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org