This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new beb8928c3a3 [SPARK-42519][CONNECT][TESTS] Add More WriteTo Tests In Spark Connect Client
beb8928c3a3 is described below

commit beb8928c3a320811c127b37368a43972fb7ad11f
Author: Hisoka <fanjiaemi...@qq.com>
AuthorDate: Mon Apr 3 14:29:13 2023 +0900

    [SPARK-42519][CONNECT][TESTS] Add More WriteTo Tests In Spark Connect Client

    ### What changes were proposed in this pull request?
    Add more WriteTo tests for Spark Connect Client

    ### Why are the changes needed?
    Improve test coverage and remove some stale TODOs

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Add new tests

    Closes #40564 from Hisoka-X/connec_test.

    Authored-by: Hisoka <fanjiaemi...@qq.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 41993250aa4943ee935376e4eba7e6e48430d298)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 120 ++++++++++++++++-----
 .../connect/client/util/IntegrationTestUtils.scala |  11 +-
 .../connect/client/util/RemoteSparkSession.scala   |  10 +-
 project/SparkBuild.scala                           |   1 +
 .../sql/connector/catalog/InMemoryBaseTable.scala  |   5 +-
 .../sql/connector/catalog/InMemoryTable.scala      |  10 +-
 6 files changed, 118 insertions(+), 39 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 605b15123c6..ee7117552c8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -220,41 +220,109 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     }
   }
 
+  test("writeTo with create") {
+    withTable("testcat.myTableV2") {
+
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
+
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+    }
+  }
+
   test("writeTo with create and using") {
-    // TODO (SPARK-42519): Add more test after we can set configs. See more WriteTo test cases
-    // in SparkConnectProtoSuite.
-    // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
-    withTable("myTableV2") {
-      spark.range(3).writeTo("myTableV2").using("parquet").create()
-      val result = spark.sql("select * from myTableV2").sort("id").collect()
-      assert(result.length == 3)
-      assert(result(0).getLong(0) == 0)
-      assert(result(1).getLong(0) == 1)
-      assert(result(2).getLong(0) == 2)
+    withTable("testcat.myTableV2") {
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+
+      val columns = spark.table("testcat.myTableV2").columns
+      assert(columns.length == 2)
+
+      val sqlOutputRows = spark.sql("select * from testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+      assert(sqlOutputRows(0).schema == schema)
+      assert(sqlOutputRows(1).getString(1) == "b")
     }
   }
 
-  // TODO (SPARK-42519): Revisit this test after we can set configs.
-  // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
   test("writeTo with create and append") {
-    withTable("myTableV2") {
-      spark.range(3).writeTo("myTableV2").using("parquet").create()
-      withTable("myTableV2") {
-        assertThrows[StatusRuntimeException] {
-          // Failed to append as Cannot write into v1 table: `spark_catalog`.`default`.`mytablev2`.
-          spark.range(3).writeTo("myTableV2").append()
-        }
+    withTable("testcat.myTableV2") {
+
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.sql("CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo")
+
+      assert(spark.table("testcat.myTableV2").collect().isEmpty)
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").append()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+    }
+  }
+
+  test("WriteTo with overwrite") {
+    withTable("testcat.myTableV2") {
+
+      val rows1 = (1L to 3L).map { i =>
+        Row(i, "" + (i - 1 + 'a'))
+      }
+      val rows2 = (4L to 7L).map { i =>
+        Row(i, "" + (i - 1 + 'a'))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.sql(
+        "CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+      assert(spark.table("testcat.myTableV2").collect().isEmpty)
+
+      spark.createDataFrame(rows1.asJava, schema).writeTo("testcat.myTableV2").append()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+
+      spark
+        .createDataFrame(rows2.asJava, schema)
+        .writeTo("testcat.myTableV2")
+        .overwrite(functions.expr("true"))
+      val outputRows2 = spark.table("testcat.myTableV2").collect()
+      assert(outputRows2.length == 4)
+    }
   }
 
-  // TODO (SPARK-42519): Revisit this test after we can set configs.
-  // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
-  test("writeTo with create") {
-    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
-    withTable("myTableV2") {
-      // Failed to create as Hive support is required.
- spark.range(3).writeTo("myTableV2").create() + test("WriteTo with overwritePartitions") { + withTable("testcat.myTableV2") { + + val rows = (4L to 7L).map { i => + Row(i, "" + (i - 1 + 'a')) + } + + val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType))) + + spark.sql( + "CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)") + + assert(spark.table("testcat.myTableV2").collect().isEmpty) + + spark + .createDataFrame(rows.asJava, schema) + .writeTo("testcat.myTableV2") + .overwritePartitions() + val outputRows = spark.table("testcat.myTableV2").collect() + assert(outputRows.length == 4) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index a98f7e9c13b..408caa58534 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -65,7 +65,11 @@ object IntegrationTestUtils { * @return * the jar */ - private[sql] def findJar(path: String, sbtName: String, mvnName: String): File = { + private[sql] def findJar( + path: String, + sbtName: String, + mvnName: String, + test: Boolean = false): File = { val targetDir = new File(new File(sparkHome, path), "target") assert( targetDir.exists(), @@ -73,14 +77,15 @@ object IntegrationTestUtils { s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " + "Make sure the spark project jars has been built (e.g. using build/sbt package)" + "and the env variable `SPARK_HOME` is set correctly.") + val suffix = if (test) "-tests.jar" else ".jar" val jars = recursiveListFiles(targetDir).filter { f => // SBT jar (f.getParentFile.getName == scalaDir && - f.getName.startsWith(sbtName) && f.getName.endsWith(".jar")) || + f.getName.startsWith(sbtName) && f.getName.endsWith(suffix)) || // Maven Jar (f.getParent.endsWith("target") && f.getName.startsWith(mvnName) && - f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar")) + f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix")) } // It is possible we found more than one: one built by maven, and another by SBT assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index d1a34603f48..43bf722020c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -58,10 +58,12 @@ object SparkConnectServerUtils { private lazy val sparkConnect: Process = { debug("Starting the Spark Connect Server...") - val jar = findJar( + val connectJar = findJar( "connector/connect/server", "spark-connect-assembly", "spark-connect").getCanonicalPath + val driverClassPath = connectJar + ":" + + findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { "hive" } else { @@ -78,16 +80,16 @@ object SparkConnectServerUtils { Seq( 
"bin/spark-submit", "--driver-class-path", - jar, + driverClassPath, "--conf", s"spark.connect.grpc.binding.port=$port", "--conf", - "spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog", + "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog", "--conf", s"spark.sql.catalogImplementation=$catalogImplementation", "--class", "org.apache.spark.sql.connect.SimpleSparkConnectService", - jar), + connectJar), new File(sparkHome)) val io = new ProcessIO( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b62de6a1629..e7854cd539d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -854,6 +854,7 @@ object SparkConnectClient { buildTestDeps := { (LocalProject("assembly") / Compile / Keys.`package`).value + (LocalProject("catalyst") / Test / Keys.`package`).value }, // SPARK-42538: Make sure the `${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars` is available for testing. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index cd7d80a8296..236fb7a6dbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -25,7 +25,6 @@ import java.util.OptionalLong import scala.collection.mutable import com.google.common.base.Objects -import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} @@ -433,7 +432,9 @@ abstract class InMemoryBaseTable( protected var streamingWriter: StreamingWrite = StreamingAppend override def overwriteDynamicPartitions(): WriteBuilder = { - assert(writer == Append) + if (writer != Append) { + throw new IllegalArgumentException(s"Unsupported writer type: $writer") + } writer = DynamicOverwrite streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 318248dae05..ee6b3c3d9a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.connector.catalog import java.util -import org.scalatest.Assertions.assert - import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage} @@ -89,14 +87,18 @@ class InMemoryTable( with SupportsOverwrite { override def truncate(): WriteBuilder = { - assert(writer == Append) + if (writer != Append) { + throw new IllegalArgumentException(s"Unsupported writer type: $writer") + } writer = TruncateAndAppend streamingWriter = StreamingTruncateAndAppend this } override def overwrite(filters: Array[Filter]): WriteBuilder = { - assert(writer == Append) + if (writer != Append) { + throw new IllegalArgumentException(s"Unsupported writer type: $writer") + } writer = new Overwrite(filters) streamingWriter = new StreamingNotSupportedOperation( s"overwrite 
         s"overwrite (${filters.mkString("filters(", ", ", ")")})")
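
For context when reading the test changes above: the new cases exercise the DataFrameWriterV2 API (`df.writeTo(...)`) against the V2 `testcat` catalog that the test server now registers via `spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog`. The sketch below is not part of the commit; it is a minimal standalone illustration of the same create/append/overwrite calls, and it assumes a local SparkSession with the spark-catalyst tests jar (which provides InMemoryTableCatalog) on the classpath, mirroring the driver-class-path change in RemoteSparkSession.scala.

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.{functions, Row, SparkSession}
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    object WriteToSketch {
      def main(args: Array[String]): Unit = {
        // Register the in-memory V2 catalog under the name "testcat",
        // the same way the Spark Connect test server is configured above.
        val spark = SparkSession
          .builder()
          .master("local[*]")
          .config(
            "spark.sql.catalog.testcat",
            "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
          .getOrCreate()

        val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
        val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
        val df = spark.createDataFrame(rows.asJava, schema)

        // create(): create the V2 table from the DataFrame (CTAS into testcat).
        df.writeTo("testcat.myTableV2").create()

        // append(): insert the same rows again into the existing table.
        df.writeTo("testcat.myTableV2").append()

        // overwrite(condition): replace every row matching the condition (here, all of them).
        df.writeTo("testcat.myTableV2").overwrite(functions.expr("true"))

        // After the overwrite, only the three rows from the last write remain.
        assert(spark.table("testcat.myTableV2").count() == 3)
        spark.stop()
      }
    }

overwritePartitions() follows the same pattern against a partitioned table, as exercised by the "WriteTo with overwritePartitions" test in the diff.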