This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 6eca5da8d3f [SPARK-44720][CONNECT] Make Dataset use Encoder instead of AgnosticEncoder
6eca5da8d3f is described below

commit 6eca5da8d3fba6d1e385f06494030996241937fa
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Aug 9 15:58:18 2023 +0200

    [SPARK-44720][CONNECT] Make Dataset use Encoder instead of AgnosticEncoder
    
    ### What changes were proposed in this pull request?
    Make the Spark Connect Dataset use `Encoder` instead of `AgnosticEncoder`.
    
    ### Why are the changes needed?
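    We want to improve binary compatibility between the Spark Connect Scala Client and the original sql/core APIs. With this change, `Dataset.encoder` has the same type (`Encoder[T]`) in both, so it no longer needs to be excluded from the compatibility check.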
    
    ### Does this PR introduce _any_ user-facing change?
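    Yes. It changes the type of `Dataset.encoder` from `AgnosticEncoder` to `Encoder`. The `AgnosticEncoder` is still available internally through the new `private[sql]` member `Dataset.agnosticEncoder`.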
    
    ### How was this patch tested?
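    Existing tests. The MiMa exclusions for `Dataset.encoder` were also removed from `CheckConnectJvmClientCompatibility`, so the compatibility check now verifies that the signature matches sql/core.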
    
    Closes #42396 from hvanhovell/SPARK-44720.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit be9ffb37585fe421705ceaa52fe49b89c50703a3)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 87 ++++++++++++----------
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  6 +-
 .../spark/sql/streaming/DataStreamWriter.scala     |  2 +-
 .../CheckConnectJvmClientCompatibility.scala       |  3 -
 4 files changed, 50 insertions(+), 48 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5f263903c8b..2d72ea6bda8 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -128,11 +128,13 @@ import org.apache.spark.util.SparkClassUtils
 class Dataset[T] private[sql] (
     val sparkSession: SparkSession,
     @DeveloperApi val plan: proto.Plan,
-    val encoder: AgnosticEncoder[T])
+    val encoder: Encoder[T])
     extends Serializable {
   // Make sure we don't forget to set plan id.
   assert(plan.getRoot.getCommon.hasPlanId)
 
+  private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)
+
   override def toString: String = {
     try {
       val builder = new mutable.StringBuilder
@@ -828,7 +830,7 @@ class Dataset[T] private[sql] (
   }
 
   private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = 
{
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getSortBuilder
         .setInput(plan.getRoot)
         .setIsGlobal(global)
@@ -878,8 +880,8 @@ class Dataset[T] private[sql] (
       ProductEncoder[(T, U)](
         
ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
         Seq(
-          EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
-          EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
+          EncoderField(s"_1", this.agnosticEncoder, leftNullable, 
Metadata.empty),
+          EncoderField(s"_2", other.agnosticEncoder, rightNullable, 
Metadata.empty)))
 
     sparkSession.newDataset(tupleEncoder) { builder =>
       val joinBuilder = builder.getJoinBuilder
@@ -889,8 +891,8 @@ class Dataset[T] private[sql] (
         .setJoinType(joinTypeValue)
         .setJoinCondition(condition.expr)
         .setJoinDataType(joinBuilder.getJoinDataTypeBuilder
-          .setIsLeftStruct(this.encoder.isStruct)
-          .setIsRightStruct(other.encoder.isStruct))
+          .setIsLeftStruct(this.agnosticEncoder.isStruct)
+          .setIsRightStruct(other.agnosticEncoder.isStruct))
     }
   }
 
@@ -1010,13 +1012,13 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def hint(name: String, parameters: Any*): Dataset[T] = 
sparkSession.newDataset(encoder) {
-    builder =>
+  def hint(name: String, parameters: Any*): Dataset[T] =
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getHintBuilder
         .setInput(plan.getRoot)
         .setName(name)
         .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
-  }
+    }
 
   private def getPlanId: Option[Long] =
     if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
@@ -1056,7 +1058,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
    */
-  def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { 
builder =>
+  def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) 
{ builder =>
     builder.getSubqueryAliasBuilder
       .setInput(plan.getRoot)
       .setAlias(alias)
@@ -1238,8 +1240,9 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
    */
-  def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) 
{ builder =>
-    
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+  def filter(condition: Column): Dataset[T] = 
sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
   }
 
   /**
@@ -1355,12 +1358,12 @@ class Dataset[T] private[sql] (
   def reduce(func: (T, T) => T): T = {
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: encoder :: Nil,
-      outputEncoder = encoder)
+      inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
+      outputEncoder = agnosticEncoder)
     val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
 
     val result = sparkSession
-      .newDataset(encoder) { builder =>
+      .newDataset(agnosticEncoder) { builder =>
         builder.getAggregateBuilder
           .setInput(plan.getRoot)
           .addAggregateExpressions(reduceExpr)
@@ -1718,7 +1721,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
    */
-  def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+  def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { 
builder =>
     builder.getLimitBuilder
       .setInput(plan.getRoot)
       .setLimit(n)
@@ -1730,7 +1733,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
    */
-  def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder 
=>
+  def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { 
builder =>
     builder.getOffsetBuilder
       .setInput(plan.getRoot)
       .setOffset(n)
@@ -1739,7 +1742,7 @@ class Dataset[T] private[sql] (
   private def buildSetOp(right: Dataset[T], setOpType: 
proto.SetOperation.SetOpType)(
       f: proto.SetOperation.Builder => Unit): Dataset[T] = {
     checkSameSparkSession(right)
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       f(
         builder.getSetOpBuilder
           .setSetOpType(setOpType)
@@ -2012,7 +2015,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): 
Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getSampleBuilder
         .setInput(plan.getRoot)
         .setWithReplacement(withReplacement)
@@ -2080,7 +2083,7 @@ class Dataset[T] private[sql] (
     normalizedCumWeights
       .sliding(2)
       .map { case Array(low, high) =>
-        sparkSession.newDataset(encoder) { builder =>
+        sparkSession.newDataset(agnosticEncoder) { builder =>
           builder.getSampleBuilder
             .setInput(sortedInput)
             .setWithReplacement(false)
@@ -2401,15 +2404,16 @@ class Dataset[T] private[sql] (
 
   private def buildDropDuplicates(
       columns: Option[Seq[String]],
-      withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) 
{ builder =>
-    val dropBuilder = builder.getDeduplicateBuilder
-      .setInput(plan.getRoot)
-      .setWithinWatermark(withinWaterMark)
-    if (columns.isDefined) {
-      dropBuilder.addAllColumnNames(columns.get.asJava)
-    } else {
-      dropBuilder.setAllColumnsAsKeys(true)
-    }
+      withinWaterMark: Boolean): Dataset[T] = 
sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      val dropBuilder = builder.getDeduplicateBuilder
+        .setInput(plan.getRoot)
+        .setWithinWatermark(withinWaterMark)
+      if (columns.isDefined) {
+        dropBuilder.addAllColumnNames(columns.get.asJava)
+      } else {
+        dropBuilder.setAllColumnsAsKeys(true)
+      }
   }
 
   /**
@@ -2630,9 +2634,9 @@ class Dataset[T] private[sql] (
   def filter(func: T => Boolean): Dataset[T] = {
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: Nil,
+      inputEncoders = agnosticEncoder :: Nil,
       outputEncoder = PrimitiveBooleanEncoder)
-    sparkSession.newDataset[T](encoder) { builder =>
+    sparkSession.newDataset[T](agnosticEncoder) { builder =>
       builder.getFilterBuilder
         .setInput(plan.getRoot)
         .setCondition(udf.apply(col("*")).expr)
@@ -2683,7 +2687,7 @@ class Dataset[T] private[sql] (
     val outputEncoder = encoderFor[U]
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: Nil,
+      inputEncoders = agnosticEncoder :: Nil,
       outputEncoder = outputEncoder)
     sparkSession.newDataset(outputEncoder) { builder =>
       builder.getMapPartitionsBuilder
@@ -2785,7 +2789,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   def tail(n: Int): Array[T] = {
-    val lastN = sparkSession.newDataset(encoder) { builder =>
+    val lastN = sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getTailBuilder
         .setInput(plan.getRoot)
         .setLimit(n)
@@ -2856,7 +2860,7 @@ class Dataset[T] private[sql] (
   }
 
   private def buildRepartition(numPartitions: Int, shuffle: Boolean): 
Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getRepartitionBuilder
         .setInput(plan.getRoot)
         .setNumPartitions(numPartitions)
@@ -2866,11 +2870,12 @@ class Dataset[T] private[sql] (
 
   private def buildRepartitionByExpression(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = 
sparkSession.newDataset(encoder) { builder =>
-    val repartitionBuilder = builder.getRepartitionByExpressionBuilder
-      .setInput(plan.getRoot)
-      .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
-    numPartitions.foreach(repartitionBuilder.setNumPartitions)
+      partitionExprs: Seq[Column]): Dataset[T] = 
sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      val repartitionBuilder = builder.getRepartitionByExpressionBuilder
+        .setInput(plan.getRoot)
+        .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+      numPartitions.foreach(repartitionBuilder.setNumPartitions)
   }
 
   /**
@@ -3183,7 +3188,7 @@ class Dataset[T] private[sql] (
    * @since 3.5.0
    */
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getWithWatermarkBuilder
         .setInput(plan.getRoot)
         .setEventTime(eventTime)
@@ -3251,7 +3256,7 @@ class Dataset[T] private[sql] (
     sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
   }
 
-  def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
+  def collectResult(): SparkResult[T] = sparkSession.execute(plan, 
agnosticEncoder)
 
   private[sql] def withResult[E](f: SparkResult[T] => E): E = {
     val result = collectResult()
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index e67ef1c0fa7..202891c66d7 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -988,15 +988,15 @@ private object KeyValueGroupedDatasetImpl {
       groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = {
     val gf = ScalarUserDefinedFunction(
       function = groupingFunc,
-      inputEncoders = ds.encoder :: Nil, // Using the original value and key 
encoders
+      inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value 
and key encoders
       outputEncoder = kEncoder)
     new KeyValueGroupedDatasetImpl(
       ds.sparkSession,
       ds.plan,
       kEncoder,
       kEncoder,
-      ds.encoder,
-      ds.encoder,
+      ds.agnosticEncoder,
+      ds.agnosticEncoder,
       Arrays.asList(gf.apply(col("*")).expr),
       UdfUtils.identical(),
       () => ds.map(groupingFunc)(kEncoder))
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index b395a2d073d..b9aa1f5bc58 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -216,7 +216,7 @@ final class DataStreamWriter[T] private[sql] (ds: 
Dataset[T]) extends Logging {
    * @since 3.5.0
    */
   def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, 
ds.encoder))
+    val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, 
ds.agnosticEncoder))
     val scalaWriterBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serialized))
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index d380a1bbb65..4439a5f3e2a 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -181,7 +181,6 @@ object CheckConnectJvmClientCompatibility {
       
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
       
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), 
// protected
@@ -334,8 +333,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.Dataset.plan"
       ), // developer API
-      ProblemFilters.exclude[IncompatibleResultTypeProblem](
-        "org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.Dataset.collectResult"),
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to