This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 3bca65a68ef [SPARK-42482][CONNECT] Scala Client Write API V1 3bca65a68ef is described below commit 3bca65a68ef9c0716eef3e7a965fe280ba751673 Author: Zhen Li <zhenli...@users.noreply.github.com> AuthorDate: Sun Feb 19 13:04:37 2023 -0400 [SPARK-42482][CONNECT] Scala Client Write API V1 ### What changes were proposed in this pull request? Implemented the basic Dataset#write API to allow users to write the df into tables, csv etc. files. ### Why are the changes needed? Basic write operation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Integration tests. Closes #40061 from zhenlineo/write. Authored-by: Zhen Li <zhenli...@users.noreply.github.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit ede1a541182e043b1f79a9ffbfc4a7fa97604078) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/java/org/apache/spark/sql/SaveMode.java | 58 +++ .../org/apache/spark/sql/DataFrameWriter.scala | 457 +++++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 10 + .../scala/org/apache/spark/sql/SparkSession.scala | 7 + .../org/apache/spark/sql/ClientE2ETestSuite.scala | 51 ++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 32 ++ .../sql/connect/client/CompatibilitySuite.scala | 4 +- .../connect/client/SparkConnectClientSuite.scala | 16 +- 8 files changed, 626 insertions(+), 9 deletions(-) diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java new file mode 100644 index 00000000000..95af157687c --- /dev/null +++ b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql; + +import org.apache.spark.annotation.Stable; + +/** + * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. + * + * @since 3.4.0 + */ +@Stable +public enum SaveMode { + /** + * Append mode means that when saving a DataFrame to a data source, if data/table already exists, + * contents of the DataFrame are expected to be appended to existing data. + * + * @since 3.4.0 + */ + Append, + /** + * Overwrite mode means that when saving a DataFrame to a data source, + * if data/table already exists, existing data is expected to be overwritten by the contents of + * the DataFrame. + * + * @since 3.4.0 + */ + Overwrite, + /** + * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, + * an exception is expected to be thrown. 
+ * + * @since 3.4.0 + */ + ErrorIfExists, + /** + * Ignore mode means that when saving a DataFrame to a data source, if data already exists, + * the save operation is expected to not save the contents of the DataFrame and to not + * change the existing data. + * + * @since 3.4.0 + */ + Ignore +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala new file mode 100644 index 00000000000..b7c4ed7bcab --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Locale + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Stable +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +/** + * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value + * stores, etc). Use `Dataset.write` to access this. + * + * @since 3.4.0 + */ +@Stable +final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) { + + /** + * Specifies the behavior when data or table already exists. Options include: <ul> + * <li>`SaveMode.Overwrite`: overwrite the existing data.</li> <li>`SaveMode.Append`: append the + * data.</li> <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li> + * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li> </ul> <p> The default + * option is `ErrorIfExists`. + * + * @since 3.4.0 + */ + def mode(saveMode: SaveMode): DataFrameWriter[T] = { + this.mode = saveMode + this + } + + /** + * Specifies the behavior when data or table already exists. Options include: <ul> + * <li>`overwrite`: overwrite the existing data.</li> <li>`append`: append the data.</li> + * <li>`ignore`: ignore the operation (i.e. no-op).</li> <li>`error` or `errorifexists`: default + * option, throw an exception at runtime.</li> </ul> + * + * @since 3.4.0 + */ + def mode(saveMode: String): DataFrameWriter[T] = { + saveMode.toLowerCase(Locale.ROOT) match { + case "overwrite" => mode(SaveMode.Overwrite) + case "append" => mode(SaveMode.Append) + case "ignore" => mode(SaveMode.Ignore) + case "error" | "errorifexists" | "default" => mode(SaveMode.ErrorIfExists) + case _ => + throw new IllegalArgumentException(s"Unknown save mode: $saveMode. Accepted " + + "save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists', 'default'.") + } + } + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. 
+ * + * @since 3.4.0 + */ + def format(source: String): DataFrameWriter[T] = { + this.source = Some(source) + this + } + + /** + * Adds an output option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def option(key: String, value: String): DataFrameWriter[T] = { + this.extraOptions = this.extraOptions + (key -> value) + this + } + + /** + * Adds an output option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString) + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * All options are maintained in a case-insensitive way in terms of key names. If a new option + * has the same key case-insensitively, it will override the existing option. + * + * @since 3.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameWriter[T] = { + this.options(options.asScala) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. As an example, when we + * partition a dataset by year and then month, the directory layout would look like: <ul> + * <li>year=2016/month=01/</li> <li>year=2016/month=02/</li> </ul> + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. It + * provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number of + * distinct values in each column should typically be less than tens of thousands. + * + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. + * + * @since 3.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataFrameWriter[T] = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Buckets the output by the given columns. 
If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme, but with a different bucket hash function and is + * not compatible with Hive's bucketing. + * + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. + * + * @since 3.4.0 + */ + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = { + require(numBuckets > 0, "The numBuckets should be > 0.") + this.numBuckets = Option(numBuckets) + this.bucketColumnNames = Option(colName +: colNames) + this + } + + /** + * Sorts the output in each bucket by the given columns. + * + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. + * + * @since 3.4.0 + */ + @scala.annotation.varargs + def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = { + this.sortColumnNames = Option(colName +: colNames) + this + } + + /** + * Saves the content of the `DataFrame` at the specified path. + * + * @since 3.4.0 + */ + def save(path: String): Unit = { + saveInternal(Some(path)) + } + + /** + * Saves the content of the `DataFrame` as the specified table. + * + * @since 3.4.0 + */ + def save(): Unit = saveInternal(None) + + private def saveInternal(path: Option[String]): Unit = { + executeWriteOperation(builder => path.foreach(builder.setPath)) + } + + private def executeWriteOperation(f: proto.WriteOperation.Builder => Unit): Unit = { + val builder = proto.WriteOperation.newBuilder() + + builder.setInput(ds.plan.getRoot) + + // Set path or table + f(builder) + require(builder.hasPath != builder.hasTable) // Only one can be set + + builder.setMode(mode match { + case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND + case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE + case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE + case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS + }) + + source.foreach(builder.setSource) + sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava)) + partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava)) + + numBuckets.foreach(n => { + val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder() + bucketBuilder.setNumBuckets(n) + bucketColumnNames.foreach(names => bucketBuilder.addAllBucketColumnNames(names.asJava)) + builder.setBucketBy(bucketBuilder) + }) + + extraOptions.foreach { case (k, v) => + builder.putOptions(k, v) + } + + ds.session.execute(proto.Command.newBuilder().setWriteOperation(builder).build()) + } + + /** + * Inserts the content of the `DataFrame` to the specified table. It requires that the schema of + * the `DataFrame` is the same as the schema of the table. + * + * @note + * Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * resolution. For example: + * + * @note + * SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as + * `insertInto` is not a table creating operation. 
+ * + * {{{ + * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") + * scala> Seq((3, 4)).toDF("j", "i").write.insertInto("t1") + * scala> Seq((5, 6)).toDF("a", "b").write.insertInto("t1") + * scala> sql("select * from t1").show + * +---+---+ + * | i| j| + * +---+---+ + * | 5| 6| + * | 3| 4| + * | 1| 2| + * +---+---+ + * }}} + * + * Because it inserts data to an existing table, format or options will be ignored. + * + * @since 3.4.0 + */ + def insertInto(tableName: String): Unit = { + executeWriteOperation(builder => { + builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO)) + }) + } + + /** + * Saves the content of the `DataFrame` as the specified table. + * + * In the case the table already exists, behavior of this function depends on the save mode, + * specified by the `mode` function (default to throwing an exception). When `mode` is + * `Overwrite`, the schema of the `DataFrame` does not need to be the same as that of the + * existing table. + * + * When `mode` is `Append`, if there is an existing table, we will use the format and options of + * the existing table. The column order in the schema of the `DataFrame` doesn't need to be same + * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names + * to find the correct column positions. For example: + * + * {{{ + * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") + * scala> Seq((3, 4)).toDF("j", "i").write.mode("append").saveAsTable("t1") + * scala> sql("select * from t1").show + * +---+---+ + * | i| j| + * +---+---+ + * | 1| 2| + * | 4| 3| + * +---+---+ + * }}} + * + * In this method, save mode is used to determine the behavior if the data source table exists + * in Spark catalog. We will always overwrite the underlying data of data source (e.g. a table + * in JDBC data source) if the table doesn't exist in Spark catalog, and will always append to + * the underlying data of data source if the table already exists. + * + * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * + * @since 3.4.0 + */ + def saveAsTable(tableName: String): Unit = { + executeWriteOperation(builder => { + builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE)) + }) + } + + /** + * Saves the content of the `DataFrame` in JSON format (<a href="http://jsonlines.org/"> JSON + * Lines text format or newline-delimited JSON</a>) at the specified path. This is equivalent + * to: + * {{{ + * format("json").save(path) + * }}} + * + * You can find the JSON-specific options for writing JSON files in <a + * href="https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option"> + * Data Source Option</a> in the version you use. + * + * @since 3.4.0 + */ + def json(path: String): Unit = { + format("json").save(path) + } + + /** + * Saves the content of the `DataFrame` in Parquet format at the specified path. 
This is + * equivalent to: + * {{{ + * format("parquet").save(path) + * }}} + * + * Parquet-specific option(s) for writing Parquet files can be found in <a href= + * "https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option"> Data + * Source Option</a> in the version you use. + * + * @since 3.4.0 + */ + def parquet(path: String): Unit = { + format("parquet").save(path) + } + + /** + * Saves the content of the `DataFrame` in ORC format at the specified path. This is equivalent + * to: + * {{{ + * format("orc").save(path) + * }}} + * + * ORC-specific option(s) for writing ORC files can be found in <a href= + * "https://spark.apache.org/docs/latest/sql-data-sources-orc.html#data-source-option"> Data + * Source Option</a> in the version you use. + * + * @since 3.4.0 + */ + def orc(path: String): Unit = { + format("orc").save(path) + } + + /** + * Saves the content of the `DataFrame` in a text file at the specified path. The DataFrame must + * have only one column that is of string type. Each row becomes a new line in the output file. + * For example: + * {{{ + * // Scala: + * df.write.text("/path/to/output") + * + * // Java: + * df.write().text("/path/to/output") + * }}} + * The text files will be encoded as UTF-8. + * + * You can find the text-specific options for writing text files in <a + * href="https://spark.apache.org/docs/latest/sql-data-sources-text.html#data-source-option"> + * Data Source Option</a> in the version you use. + * + * @since 3.4.0 + */ + def text(path: String): Unit = { + format("text").save(path) + } + + /** + * Saves the content of the `DataFrame` in CSV format at the specified path. This is equivalent + * to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can find the CSV-specific options for writing CSV files in <a + * href="https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option"> + * Data Source Option</a> in the version you use. + * + * @since 3.4.0 + */ + def csv(path: String): Unit = { + format("csv").save(path) + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: Option[String] = None + + private var mode: SaveMode = SaveMode.ErrorIfExists + + private var extraOptions = CaseInsensitiveMap[String](Map.empty) + + private var partitioningColumns: Option[Seq[String]] = None + + private var bucketColumnNames: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortColumnNames: Option[Seq[String]] = None +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c34b45fccb..3c876c05432 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2246,6 +2246,16 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: */ def inputFiles: Array[String] = analyze.getInputFilesList.asScala.toArray + /** + * Interface for saving the content of the non-streaming Dataset out into external storage. 
+ * + * @group basic + * @since 3.4.0 + */ + def write: DataFrameWriter[T] = { + new DataFrameWriter[T](this) + } + private[sql] def analyze: proto.AnalyzePlanResponse = { session.analyze(plan, proto.Explain.ExplainMode.SIMPLE) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 54871c99b56..1761e8ce42d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import java.io.Closeable +import scala.collection.JavaConverters._ + import org.apache.arrow.memory.RootAllocator import org.apache.spark.connect.proto @@ -162,6 +164,11 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner: result } + private[sql] def execute(command: proto.Command): Unit = { + val plan = proto.Plan.newBuilder().setCommand(command).build() + client.execute(plan).asScala.foreach(_ => ()) + } + override def close(): Unit = { client.shutdown() allocator.close() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 058ba1a8efc..145d62feefc 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -21,6 +21,8 @@ import java.io.{ByteArrayOutputStream, PrintStream} import scala.collection.JavaConverters._ import io.grpc.StatusRuntimeException +import java.nio.file.Files +import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.scalactic.TolerantNumerics @@ -67,7 +69,7 @@ class ClientE2ETestSuite extends RemoteSparkSession { } } - test("read") { + test("read and write") { val testDataPath = java.nio.file.Paths .get( IntegrationTestUtils.sparkHome, @@ -91,11 +93,20 @@ class ClientE2ETestSuite extends RemoteSparkSession { StructField("age", IntegerType) :: StructField("job", StringType) :: Nil)) .load() - val array = df.collectResult().toArray - assert(array.length == 2) - assert(array(0).getString(0) == "Jorge") - assert(array(0).getInt(1) == 30) - assert(array(0).getString(2) == "Developer") + val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath + + df.write + .format("csv") + .mode("overwrite") + .options(Map("header" -> "true", "delimiter" -> ";")) + .save(outputFolderPath.toString) + + // We expect only one csv file saved. + val outputFile = outputFolderPath.toFile + .listFiles() + .filter(file => file.getPath.endsWith(".csv"))(0) + + assert(FileUtils.contentEquals(testDataPath.toFile, outputFile)) } test("read path collision") { @@ -124,7 +135,33 @@ class ClientE2ETestSuite extends RemoteSparkSession { .csv(testDataPath.toString) // Failed because the path cannot be provided both via option and load method (csv). 
assertThrows[StatusRuntimeException] { - df.collectResult().toArray + df.collect() + } + } + + test("write table") { + try { + val df = spark.range(10).limit(3) + df.write.mode(SaveMode.Overwrite).saveAsTable("myTable") + spark.range(2).write.insertInto("myTable") + val result = spark.sql("select * from myTable").sort("id").collect() + assert(result.length == 5) + assert(result(0).getLong(0) == 0) + assert(result(1).getLong(0) == 0) + assert(result(2).getLong(0) == 1) + assert(result(3).getLong(0) == 1) + assert(result(4).getLong(0) == 2) + } finally { + spark.sql("drop table if exists myTable").collect() + } + } + + test("write path collision") { + val df = spark.range(10) + val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath + // Failed because the path cannot be provided both via option and save method. + assertThrows[StatusRuntimeException] { + df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 087dcbb360a..66e597f2457 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -115,4 +115,36 @@ class DatasetSuite val actualPlan = service.getAndClearLatestInputPlan() assert(actualPlan.equals(expectedPlan)) } + + test("write") { + val df = ss.newDataset(_ => ()).limit(10) + + val builder = proto.WriteOperation.newBuilder() + builder + .setInput(df.plan.getRoot) + .setPath("my/test/path") + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS) + .setSource("parquet") + .addSortColumnNames("col1") + .addPartitioningColumns("col99") + .setBucketBy( + proto.WriteOperation.BucketBy + .newBuilder() + .setNumBuckets(2) + .addBucketColumnNames("col1") + .addBucketColumnNames("col2")) + + val expectedPlan = proto.Plan + .newBuilder() + .setCommand(proto.Command.newBuilder().setWriteOperation(builder)) + .build() + + df.write + .sortBy("col1") + .partitionBy("col99") + .bucketBy(2, "col1", "col2") + .parquet("my/test/path") + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan)) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala index fa2cb18cda2..81d58566cd9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala @@ -76,6 +76,7 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite // IncludeByName("org.apache.spark.sql.Dataset$"), IncludeByName("org.apache.spark.sql.DataFrame"), IncludeByName("org.apache.spark.sql.DataFrameReader"), + IncludeByName("org.apache.spark.sql.DataFrameWriter"), IncludeByName("org.apache.spark.sql.SparkSession"), IncludeByName("org.apache.spark.sql.SparkSession$")) ++ includeImplementedMethods(clientJar) val excludeRules = Seq( @@ -135,7 +136,8 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite // the Dataset methods, as too many overload methods are missing. 
// "org.apache.spark.sql.Dataset", "org.apache.spark.sql.SparkSession", - "org.apache.spark.sql.DataFrameReader") + "org.apache.spark.sql.DataFrameReader", + "org.apache.spark.sql.DataFrameWriter") val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray) clsNames diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index f3caba28ffd..908eddbe7bf 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, SparkConnectServiceGrpc} +import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.connect.common.config.ConnectCommon class SparkConnectClientSuite @@ -160,6 +160,20 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer plan } + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + // Reply with a dummy response using the same client ID + val requestClientId = request.getClientId + inputPlan = request.getPlan + val response = ExecutePlanResponse + .newBuilder() + .setClientId(requestClientId) + .build() + responseObserver.onNext(response) + responseObserver.onCompleted() + } + override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org