This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 79092f081aa [SPARK-42580][CONNECT] Scala client add client side typed APIs
79092f081aa is described below

commit 79092f081aafd3ac6718df2f4c65475bc8161638
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Feb 27 09:15:09 2023 -0400

    [SPARK-42580][CONNECT] Scala client add client side typed APIs

    ### What changes were proposed in this pull request?
    This PR adds the client side typed API to the Spark Connect Scala Client.

    ### Why are the changes needed?
    We want to reach API parity with the existing APIs.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds user API.

    ### How was this patch tested?
    Added tests to `ClientE2ETestSuite`, and updated existing tests.

    Closes #40175 from hvanhovell/SPARK-42580.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 5243d0be2c15e3af36e981a9487ea600ab86a808)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../org/apache/spark/sql/DataFrameReader.scala     |   4 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 122 +++++++++++++--------
 .../spark/sql/RelationalGroupedDataset.scala       |   2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  29 +++--
 .../spark/sql/connect/client/SparkResult.scala     |  37 ++++---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  80 ++++++++++++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |   6 +-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   2 +-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  11 +-
 9 files changed, 204 insertions(+), 89 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 5a486efee31..3e17b03173b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -171,7 +171,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
    */
   @scala.annotation.varargs
   def load(paths: String*): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
       assertSourceFormatSpecified()
       dataSourceBuilder.setFormat(source)
@@ -308,7 +308,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
    * @since 3.4.0
    */
   def table(tableName: String): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       builder.getReadBuilder.getNamedTableBuilder.setUnparsedIdentifier(tableName)
     }
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index dcc770dfe55..73de35456fc 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -23,6 +23,8 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import
org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.DataTypeProtoConverter @@ -116,7 +118,10 @@ import org.apache.spark.util.Utils * * @since 3.4.0 */ -class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val plan: proto.Plan) +class Dataset[T] private[sql] ( + val sparkSession: SparkSession, + private[sql] val plan: proto.Plan, + val encoder: AgnosticEncoder[T]) extends Serializable { // Make sure we don't forget to set plan id. assert(plan.getRoot.getCommon.hasPlanId) @@ -151,9 +156,32 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group basic * @since 3.4.0 */ - def toDF(): DataFrame = { - // Note this will change as soon as we add the typed APIs. - this.asInstanceOf[Dataset[Row]] + def toDF(): DataFrame = new Dataset(sparkSession, plan, UnboundRowEncoder) + + /** + * Returns a new Dataset where each record has been mapped on to the specified type. The method + * used to map columns depend on the type of `U`: <ul> <li>When `U` is a class, fields for the + * class will be mapped to columns of the same name (case sensitivity is determined by + * `spark.sql.caseSensitive`).</li> <li>When `U` is a tuple, the columns will be mapped by + * ordinal (i.e. the first column will be assigned to `_1`).</li> <li>When `U` is a primitive + * type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.</li> + * </ul> + * + * If the schema of the Dataset does not match the desired `U` type, you can use `select` along + * with `alias` or `as` to rearrange or rename as required. + * + * Note that `as[]` only changes the view of the data that is passed into typed operations, such + * as `map()`, and does not eagerly project away any columns that are not present in the + * specified class. + * + * @group basic + * @since 3.4.0 + */ + def as[U: Encoder]: Dataset[U] = { + val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]] + // We should add some validation/coercion here. We cannot use `to` + // because that does not work with positional arguments. 
+ new Dataset[U](sparkSession, plan, encoder) } /** @@ -170,7 +198,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ @scala.annotation.varargs - def toDF(colNames: String*): DataFrame = sparkSession.newDataset { builder => + def toDF(colNames: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getToDfBuilder .setInput(plan.getRoot) .addAllColumnNames(colNames.asJava) @@ -192,7 +220,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group basic * @since 3.4.0 */ - def to(schema: StructType): DataFrame = sparkSession.newDataset { builder => + def to(schema: StructType): DataFrame = sparkSession.newDataFrame { builder => builder.getToSchemaBuilder .setInput(plan.getRoot) .setSchema(DataTypeProtoConverter.toConnectProtoType(schema)) @@ -205,7 +233,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def schema: StructType = { - DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType] + if (encoder == UnboundRowEncoder) { + DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType] + } else { + encoder.schema + } } /** @@ -469,7 +501,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = { - val df = sparkSession.newDataset { builder => + val df = sparkSession.newDataset(StringEncoder) { builder => builder.getShowStringBuilder .setInput(plan.getRoot) .setNumRows(numRows) @@ -480,13 +512,13 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val assert(result.length == 1) assert(result.schema.size == 1) // scalastyle:off println - println(result.toArray.head.getString(0)) + println(result.toArray.head) // scalastyle:on println } } private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = { - sparkSession.newDataset { builder => + sparkSession.newDataFrame { builder => val joinBuilder = builder.getJoinBuilder joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot) f(joinBuilder) @@ -752,7 +784,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val } private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { - sparkSession.newDataset { builder => + sparkSession.newDataset(encoder) { builder => builder.getSortBuilder .setInput(plan.getRoot) .setIsGlobal(global) @@ -860,11 +892,12 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset { builder => - builder.getHintBuilder - .setInput(plan.getRoot) - .setName(name) - .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava) + def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) { + builder => + builder.getHintBuilder + .setInput(plan.getRoot) + .setName(name) + .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava) } /** @@ -900,7 +933,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def as(alias: String): Dataset[T] = sparkSession.newDataset { builder => + def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder => builder.getSubqueryAliasBuilder .setInput(plan.getRoot) .setAlias(alias) @@ -940,7 +973,7 @@ class Dataset[T] 
private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = sparkSession.newDataset { builder => + def select(cols: Column*): DataFrame = sparkSession.newDataFrame { builder => builder.getProjectBuilder .setInput(plan.getRoot) .addAllExpressions(cols.map(_.expr).asJava) @@ -990,7 +1023,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def filter(condition: Column): Dataset[T] = sparkSession.newDataset { builder => + def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder => builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) } @@ -1033,7 +1066,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val ids: Array[Column], valuesOption: Option[Array[Column]], variableColumnName: String, - valueColumnName: String): DataFrame = sparkSession.newDataset { builder => + valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder => val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toSeq.map(_.expr).asJava) @@ -1423,7 +1456,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def limit(n: Int): Dataset[T] = sparkSession.newDataset { builder => + def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder => builder.getLimitBuilder .setInput(plan.getRoot) .setLimit(n) @@ -1435,7 +1468,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = sparkSession.newDataset { builder => + def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder => builder.getOffsetBuilder .setInput(plan.getRoot) .setOffset(n) @@ -1443,7 +1476,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)( f: proto.SetOperation.Builder => Unit): Dataset[T] = { - sparkSession.newDataset { builder => + sparkSession.newDataset(encoder) { builder => f( builder.getSetOpBuilder .setSetOpType(setOpType) @@ -1707,7 +1740,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { - sparkSession.newDataset { builder => + sparkSession.newDataset(encoder) { builder => builder.getSampleBuilder .setInput(plan.getRoot) .setWithReplacement(withReplacement) @@ -1775,7 +1808,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val normalizedCumWeights .sliding(2) .map { case Array(low, high) => - sparkSession.newDataset[T] { builder => + sparkSession.newDataset(encoder) { builder => builder.getSampleBuilder .setInput(sortedInput) .setWithReplacement(false) @@ -1819,7 +1852,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val val aliases = values.zip(names).map { case (value, name) => value.name(name).expr.getAlias } - sparkSession.newDataset { builder => + sparkSession.newDataFrame { builder => builder.getWithColumnsBuilder .setInput(plan.getRoot) .addAllAliases(aliases.asJava) @@ -1910,7 +1943,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame = { - 
sparkSession.newDataset { builder => + sparkSession.newDataFrame { builder => builder.getWithColumnsRenamedBuilder .setInput(plan.getRoot) .putAllRenameColumnsMap(colsMap) @@ -1929,7 +1962,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val .setExpr(col(columnName).expr) .addName(columnName) .setMetadata(metadata.json) - sparkSession.newDataset { builder => + sparkSession.newDataFrame { builder => builder.getWithColumnsBuilder .setInput(plan.getRoot) .addAliases(newAlias) @@ -2083,7 +2116,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val @scala.annotation.varargs def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols) - private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataset { builder => + private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder => builder.getDropBuilder .setInput(plan.getRoot) .addAllCols(cols.map(_.expr).asJava) @@ -2096,7 +2129,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def dropDuplicates(): Dataset[T] = sparkSession.newDataset { builder => + def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder => builder.getDeduplicateBuilder .setInput(plan.getRoot) .setAllColumnsAsKeys(true) @@ -2109,10 +2142,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @group typedrel * @since 3.4.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset { builder => - builder.getDeduplicateBuilder - .setInput(plan.getRoot) - .addAllColumnNames(colNames.asJava) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) { + builder => + builder.getDeduplicateBuilder + .setInput(plan.getRoot) + .addAllColumnNames(colNames.asJava) } /** @@ -2166,7 +2200,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = sparkSession.newDataset { builder => + def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getDescribeBuilder .setInput(plan.getRoot) .addAllCols(cols.asJava) @@ -2241,7 +2275,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ @scala.annotation.varargs - def summary(statistics: String*): DataFrame = sparkSession.newDataset { builder => + def summary(statistics: String*): DataFrame = sparkSession.newDataFrame { builder => builder.getSummaryBuilder .setInput(plan.getRoot) .addAllStatistics(statistics.asJava) @@ -2309,7 +2343,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def tail(n: Int): Array[T] = { - val lastN = sparkSession.newDataset[T] { builder => + val lastN = sparkSession.newDataset(encoder) { builder => builder.getTailBuilder .setInput(plan.getRoot) .setLimit(n) @@ -2340,7 +2374,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def collect(): Array[T] = withResult { result => - result.toArray.asInstanceOf[Array[T]] + result.toArray } /** @@ -2368,7 +2402,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val */ def toLocalIterator(): java.util.Iterator[T] = { // TODO make this a destructive iterator. 
- collectResult().iterator.asInstanceOf[java.util.Iterator[T]] + collectResult().iterator } /** @@ -2377,11 +2411,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val * @since 3.4.0 */ def count(): Long = { - groupBy().count().collect().head.getLong(0) + groupBy().count().as(PrimitiveLongEncoder).collect().head } private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = { - sparkSession.newDataset { builder => + sparkSession.newDataset(encoder) { builder => builder.getRepartitionBuilder .setInput(plan.getRoot) .setNumPartitions(numPartitions) @@ -2391,7 +2425,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val private def buildRepartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset { builder => + partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder => val repartitionBuilder = builder.getRepartitionByExpressionBuilder .setInput(plan.getRoot) .addAllPartitionExprs(partitionExprs.map(_.expr).asJava) @@ -2651,9 +2685,9 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val sparkSession.analyze(plan, proto.Explain.ExplainMode.SIMPLE) } - def collectResult(): SparkResult = sparkSession.execute(plan) + def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder) - private[sql] def withResult[E](f: SparkResult => E): E = { + private[sql] def withResult[E](f: SparkResult[T] => E): E = { val result = collectResult() try f(result) finally { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 76d3ab5cf09..89bc5bfec57 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -42,7 +42,7 @@ class RelationalGroupedDataset protected[sql] ( pivot: Option[proto.Aggregate.Pivot] = None) { private[this] def toDF(aggExprs: Seq[Column]): DataFrame = { - df.sparkSession.newDataset { builder => + df.sparkSession.newDataFrame { builder => builder.getAggregateBuilder .setInput(df.plan.getRoot) .addAllGroupingExpressions(groupingExprs.asJava) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index e85c7008ca9..3aed781855c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -28,6 +28,8 @@ import org.apache.spark.SPARK_VERSION import org.apache.spark.annotation.Experimental import org.apache.spark.connect.proto import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.util.Cleaner @@ -118,7 +120,7 @@ class SparkSession( * @since 3.4.0 */ @Experimental - def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataset { + def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame { builder => builder 
.setSql(proto.SQL.newBuilder().setQuery(sqlText).putAllArgs(args)) @@ -169,7 +171,7 @@ class SparkSession( * * @since 3.4.0 */ - def range(end: Long): Dataset[Row] = range(0, end) + def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a @@ -177,7 +179,7 @@ class SparkSession( * * @since 3.4.0 */ - def range(start: Long, end: Long): Dataset[Row] = { + def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1) } @@ -187,7 +189,7 @@ class SparkSession( * * @since 3.4.0 */ - def range(start: Long, end: Long, step: Long): Dataset[Row] = { + def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, None) } @@ -197,7 +199,7 @@ class SparkSession( * * @since 3.4.0 */ - def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[Row] = { + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { range(start, end, step, Option(numPartitions)) } @@ -221,8 +223,8 @@ class SparkSession( start: Long, end: Long, step: Long, - numPartitions: Option[Int]): Dataset[Row] = { - newDataset { builder => + numPartitions: Option[Int]): Dataset[java.lang.Long] = { + newDataset(BoxedLongEncoder) { builder => val rangeBuilder = builder.getRangeBuilder .setStart(start) .setEnd(end) @@ -231,12 +233,17 @@ class SparkSession( } } - private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = { + private[sql] def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = { + newDataset(UnboundRowEncoder)(f) + } + + private[sql] def newDataset[T](encoder: AgnosticEncoder[T])( + f: proto.Relation.Builder => Unit): Dataset[T] = { val builder = proto.Relation.newBuilder() f(builder) builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement()) val plan = proto.Plan.newBuilder().setRoot(builder).build() - new Dataset[T](this, plan) + new Dataset[T](this, plan, encoder) } private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = { @@ -250,9 +257,9 @@ class SparkSession( mode: proto.Explain.ExplainMode): proto.AnalyzePlanResponse = client.analyze(plan, mode) - private[sql] def execute(plan: proto.Plan): SparkResult = { + private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = client.execute(plan) - val result = new SparkResult(value, allocator) + val result = new SparkResult(value, allocator, encoder) cleaner.register(result) result } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 317c20cad3e..80db558918b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -26,26 +26,37 @@ import org.apache.arrow.vector.FieldVector import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.connect.proto -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -private[sql] class SparkResult( +private[sql] class SparkResult[T]( responses: java.util.Iterator[proto.ExecutePlanResponse], - allocator: BufferAllocator) + allocator: BufferAllocator, + encoder: AgnosticEncoder[T]) extends AutoCloseable with Cleanable { private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ - private[this] var encoder: ExpressionEncoder[Row] = _ + private[this] var boundEncoder: ExpressionEncoder[T] = _ private[this] val batches = mutable.Buffer.empty[ColumnarBatch] + private def createEncoder(schema: StructType): ExpressionEncoder[T] = { + val agnosticEncoder = if (encoder == UnboundRowEncoder) { + // Create a row encoder based on the schema. + RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]] + } else { + encoder + } + ExpressionEncoder(agnosticEncoder) + } + private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = { while (responses.hasNext) { val response = responses.next() @@ -57,7 +68,7 @@ private[sql] class SparkResult( if (batches.isEmpty) { structType = ArrowUtils.fromArrowSchema(root.getSchema) // TODO: create encoders that directly operate on arrow vectors. - encoder = RowEncoder(structType).resolveAndBind(structType.toAttributes) + boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes) } while (reader.loadNextBatch()) { val rowCount = root.getRowCount @@ -108,8 +119,8 @@ private[sql] class SparkResult( /** * Create an Array with the contents of the result. */ - def toArray: Array[Row] = { - val result = new Array[Row](length) + def toArray: Array[T] = { + val result = encoder.clsTag.newArray(length) val rows = iterator var i = 0 while (rows.hasNext) { @@ -123,11 +134,11 @@ private[sql] class SparkResult( /** * Returns an iterator over the contents of the result. 
*/ - def iterator: java.util.Iterator[Row] with AutoCloseable = { - new java.util.Iterator[Row] with AutoCloseable { + def iterator: java.util.Iterator[T] with AutoCloseable = { + new java.util.Iterator[T] with AutoCloseable { private[this] var batchIndex: Int = -1 private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator() - private[this] var deserializer: Deserializer[Row] = _ + private[this] var deserializer: Deserializer[T] = _ override def hasNext: Boolean = { if (iterator.hasNext) { return true @@ -142,13 +153,13 @@ private[sql] class SparkResult( batchIndex = nextBatchIndex iterator = batches(nextBatchIndex).rowIterator() if (deserializer == null) { - deserializer = encoder.createDeserializer() + deserializer = boundEncoder.createDeserializer() } } hasNextBatch } - override def next(): Row = { + override def next(): T = { if (!hasNext) { throw new NoSuchElementException } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 122e7d5d271..debb314f8c3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.file.Files import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe.TypeTag import io.grpc.StatusRuntimeException -import java.nio.file.Files import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.scalactic.TolerantNumerics import org.apache.spark.SPARK_VERSION +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf} import org.apache.spark.sql.types._ @@ -54,9 +56,9 @@ class ClientE2ETestSuite extends RemoteSparkSession { val df = spark.range(10).limit(3) val result = df.collect() assert(result.length == 3) - assert(result(0).getLong(0) == 0) - assert(result(1).getLong(0) == 1) - assert(result(2).getLong(0) == 2) + assert(result(0) == 0) + assert(result(1) == 1) + assert(result(2) == 2) } test("simple udf") { @@ -237,30 +239,40 @@ class ClientE2ETestSuite extends RemoteSparkSession { checkFragments(result, fragmentsToCheck) } - private val simpleSchema = new StructType().add("id", "long", nullable = false) + private val simpleSchema = new StructType().add("value", "long", nullable = true) // Dataset tests test("Dataset inspection") { val df = spark.range(10) - val local = spark.newDataset { builder => + val local = spark.newDataFrame { builder => builder.getLocalRelationBuilder.setSchema(simpleSchema.catalogString) } assert(!df.isLocal) assert(local.isLocal) assert(!df.isStreaming) - assert(df.toString.contains("[id: bigint]")) + assert(df.toString.contains("[value: bigint]")) assert(df.inputFiles.isEmpty) } test("Dataset schema") { val df = spark.range(10) assert(df.schema === simpleSchema) - assert(df.dtypes === Array(("id", "LongType"))) - assert(df.columns === Array("id")) + assert(df.dtypes === Array(("value", "LongType"))) + assert(df.columns === Array("value")) testCapturedStdOut(df.printSchema(), simpleSchema.treeString) testCapturedStdOut(df.printSchema(5), 
simpleSchema.treeString(5)) } + test("Dataframe schema") { + val df = spark.sql("select * from range(10)") + val expectedSchema = new StructType().add("id", "long", nullable = false) + assert(df.schema === expectedSchema) + assert(df.dtypes === Array(("id", "LongType"))) + assert(df.columns === Array("id")) + testCapturedStdOut(df.printSchema(), expectedSchema.treeString) + testCapturedStdOut(df.printSchema(5), expectedSchema.treeString(5)) + } + test("Dataset explain") { val df = spark.range(10) val simpleExplainFragments = Seq("== Physical Plan ==") @@ -282,9 +294,9 @@ class ClientE2ETestSuite extends RemoteSparkSession { } test("Dataset result collection") { - def checkResult(rows: TraversableOnce[Row], expectedValues: Long*): Unit = { + def checkResult(rows: TraversableOnce[java.lang.Long], expectedValues: Long*): Unit = { rows.toIterator.zipAll(expectedValues.iterator, null, null).foreach { - case (actual, expected) => assert(actual.getLong(0) === expected) + case (actual, expected) => assert(actual === expected) } } val df = spark.range(10) @@ -355,7 +367,11 @@ class ClientE2ETestSuite extends RemoteSparkSession { implicit val tolerance = TolerantNumerics.tolerantDoubleEquality(0.01) val df = spark.range(100) - def checkSample(ds: DataFrame, lower: Double, upper: Double, seed: Long): Unit = { + def checkSample( + ds: Dataset[java.lang.Long], + lower: Double, + upper: Double, + seed: Long): Unit = { assert(ds.plan.getRoot.hasSample) val sample = ds.plan.getRoot.getSample assert(sample.getSeed === seed) @@ -375,6 +391,44 @@ class ClientE2ETestSuite extends RemoteSparkSession { checkSample(datasets.get(3), 6.0 / 10.0, 1.0, 9L) } + test("Dataset count") { + assert(spark.range(10).count() === 10) + } + + // We can remove this as soon this is added to SQLImplicits. + private implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = + ScalaReflection.encoderFor[T] + + test("Dataset collect tuple") { + val result = spark + .range(3) + .select(col("id"), (col("id") % 2).cast("int").as("a"), (col("id") / lit(10.0d)).as("b")) + .as[(Long, Int, Double)] + .collect() + result.zipWithIndex.foreach { case ((id, a, b), i) => + assert(id == i) + assert(a == id % 2) + assert(b == id / 10.0d) + } + } + + test("Dataset collect complex type") { + val result = spark + .range(3) + .select( + (col("id") / lit(10.0d)).as("b"), + col("id"), + lit("world").as("d"), + (col("id") % 2).cast("int").as("a")) + .as[MyType] + .collect() + result.zipWithIndex.foreach { case (MyType(id, a, b), i) => + assert(id == i) + assert(a == id % 2) + assert(b == id / 10.0d) + } + } + test("lambda functions") { // This test is mostly to validate lambda variables are properly resolved. 
val result = spark @@ -447,3 +501,5 @@ class ClientE2ETestSuite extends RemoteSparkSession { intercept[Exception](spark.conf.set("spark.sql.globalTempDatabase", "/dev/null")) } } + +private[sql] case class MyType(id: Long, a: Double, b: Double) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9c07c5abe3c..4a26a32353a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -69,7 +69,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { } test("write") { - val df = ss.newDataset(_ => ()).limit(10) + val df = ss.newDataFrame(_ => ()).limit(10) val builder = proto.WriteOperation.newBuilder() builder @@ -101,7 +101,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { } test("write V2") { - val df = ss.newDataset(_ => ()).limit(10) + val df = ss.newDataFrame(_ => ()).limit(10) val builder = proto.WriteOperationV2.newBuilder() builder @@ -129,7 +129,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { } test("Pivot") { - val df = ss.newDataset(_ => ()) + val df = ss.newDataFrame(_ => ()) intercept[IllegalArgumentException] { df.groupBy().pivot(Column("c"), Seq(Column("col"))) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 6a789b1494f..67ea148cb87 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -215,7 +215,7 @@ class PlanGenerationTestSuite private val temporalsSchemaString = temporalsSchema.catalogString - private def createLocalRelation(schema: String): DataFrame = session.newDataset { builder => + private def createLocalRelation(schema: String): DataFrame = session.newDataFrame { builder => // TODO API is not consistent. Now we have two different ways of working with schemas! 
     builder.getLocalRelationBuilder.setSchema(schema)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 1a3c1089649..24c8bad5c2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -107,13 +107,20 @@ object AgnosticEncoders {
     override def dataType: DataType = schema
   }
 
-  case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
+  abstract class BaseRowEncoder extends AgnosticEncoder[Row] {
     override def isPrimitive: Boolean = false
-    override val schema: StructType = StructType(fields.map(_.structField))
     override def dataType: DataType = schema
     override def clsTag: ClassTag[Row] = classTag[Row]
   }
 
+  case class RowEncoder(fields: Seq[EncoderField]) extends BaseRowEncoder {
+    override val schema: StructType = StructType(fields.map(_.structField))
+  }
+
+  object UnboundRowEncoder extends BaseRowEncoder {
+    override val schema: StructType = new StructType()
+  }
+
   case class JavaBeanEncoder[K](
       override val clsTag: ClassTag[K],
       fields: Seq[EncoderField])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
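As a usage reference for the change above, the sketch below is a minimal, illustrative example of the typed client API this commit adds, adapted from the new `ClientE2ETestSuite` cases ("Dataset count", "Dataset collect tuple"). The object and method names are hypothetical, and the example assumes an already-connected Spark Connect `SparkSession` is passed in (the tests obtain theirs from the `RemoteSparkSession` harness). As the test suite notes, the `Product` encoder implicit is declared locally because it is not yet provided by `SQLImplicits` at this point.

```scala
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.functions.{col, lit}

// Hypothetical example object; names are illustrative only.
object TypedClientExample {

  // Tuple/case-class encoders are not yet exposed through SQLImplicits at this
  // commit, so the implicit is declared directly, as ClientE2ETestSuite does.
  implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] =
    ScalaReflection.encoderFor[T]

  def run(spark: SparkSession): Unit = {
    // range() now yields Dataset[java.lang.Long] instead of Dataset[Row], so
    // collect() returns boxed longs and count() a plain Long.
    val ids: Array[java.lang.Long] = spark.range(3).collect()
    assert(spark.range(10).count() == 10)

    // as[U] attaches a client-side encoder; collect() then returns typed values
    // rather than Rows (mirrors the "Dataset collect tuple" test).
    val typed: Array[(Long, Int, Double)] = spark
      .range(3)
      .select(col("id"), (col("id") % 2).cast("int").as("a"), (col("id") / lit(10.0d)).as("b"))
      .as[(Long, Int, Double)]
      .collect()

    typed.foreach { case (id, a, b) => println(s"id=$id a=$a b=$b") }
    println(ids.mkString(", "))
  }
}
```

The same pattern works for case classes (see the new `MyType` test): `as[U]` only swaps the `AgnosticEncoder` carried by the `Dataset`, the plan sent to the server is unchanged, and deserialization of the Arrow result batches into `U` happens client-side in `SparkResult`.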