This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 8f6f301fa77 [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders.scala
8f6f301fa77 is described below

commit 8f6f301fa778dfd0fd7dec4a29df7106846d3277
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Aug 7 15:09:58 2023 +0200

    [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders.scala
    
    ### What changes were proposed in this pull request?
    This PR adds a `row(schema: StructType)` method to `Encoders`, in both the sql and connect Scala/Java APIs, so that an `Encoder[Row]` can be created through a public API. As part of this, `RowFactory` moves from sql/catalyst to sql/api.
    
    ### Why are the changes needed?
    It is currently not possible to create a `RowEncoder` using the public API. The internal APIs for this will change in Spark 3.5, which means that library maintainers will have to update their code if they use a `RowEncoder`. To avoid this happening again, we add this method to the public API.
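    
    A minimal sketch of the change in caller code (the commented-out line illustrates the kind of internal, unstable call that library code had to use before; it is not a stable API):
    
    ```scala
    import org.apache.spark.sql.{Encoder, Encoders, Row}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
    
    val schema = new StructType()
      .add("a", IntegerType)
      .add("b", StringType)
    
    // Before: reach into the internal catalyst package (changed in 3.5).
    // val enc = org.apache.spark.sql.catalyst.encoders.RowEncoder(schema)
    
    // After: use the public factory method added by this commit.
    val enc: Encoder[Row] = Encoders.row(schema)
    ```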
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds the `row` method to `Encoders`.
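    
    A minimal Scala sketch of the new method in use, mirroring the Java tests added below (assumes an existing `SparkSession` named `spark`):
    
    ```scala
    import org.apache.spark.sql.{Encoders, Row}
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.StructType
    
    val schema = new StructType()
      .add("a", "int")
      .add("b", "string")
    
    // Map a range of longs to Rows, passing the public Row encoder explicitly.
    val df = spark.range(3)
      .map(i => Row(i.intValue(), "s" + i))(Encoders.row(schema))
      .filter(col("a") >= 1)
    
    df.collectAsList()  // [[1,s1], [2,s2]]
    ```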
    
    ### How was this patch tested?
    Added tests to connect and sql.
    
    Closes #42366 from hvanhovell/SPARK-44686.
    
    Lead-authored-by: Herman van Hovell <her...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit bf7654998fbbec9d5bdee6f46462cffef495545f)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Encoders.scala | 10 ++++++-
 .../org/apache/spark/sql/JavaEncoderSuite.java     | 31 +++++++++++++++++++---
 project/MimaExcludes.scala                         |  2 ++
 .../main/java/org/apache/spark/sql/RowFactory.java |  0
 .../main/scala/org/apache/spark/sql/Encoders.scala |  7 +++++
 .../org/apache/spark/sql/JavaDatasetSuite.java     | 19 +++++++++++++
 6 files changed, 64 insertions(+), 5 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
index 3f2f7ec96d4..74f01338031 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.types.StructType
 
 /**
  * Methods for creating an [[Encoder]].
@@ -168,6 +169,13 @@ object Encoders {
    */
   def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass)
 
+  /**
+   * Creates a [[Row]] encoder for schema `schema`.
+   *
+   * @since 3.5.0
+   */
+  def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema)
+
   private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
     ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
   }
diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
index c8210a7a485..6e5fb72d496 100644
--- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
+++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
@@ -16,21 +16,26 @@
  */
 package org.apache.spark.sql;
 
+import java.io.Serializable;
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+
 import org.junit.*;
 import static org.junit.Assert.*;
 
 import static org.apache.spark.sql.Encoders.*;
 import static org.apache.spark.sql.functions.*;
+import static org.apache.spark.sql.RowFactory.create;
 import org.apache.spark.sql.connect.client.SparkConnectClient;
 import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils;
-
-import java.math.BigDecimal;
-import java.util.Arrays;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.types.StructType;
 
 /**
  * Tests for the encoders class.
  */
-public class JavaEncoderSuite {
+public class JavaEncoderSuite implements Serializable {
   private static SparkSession spark;
 
   @BeforeClass
@@ -91,4 +96,22 @@ public class JavaEncoderSuite {
         dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2))
             .select(sum(v)).as(DECIMAL()).head().setScale(2));
   }
+
+  @Test
+  public void testRowEncoder() {
+    final StructType schema = new StructType()
+        .add("a", "int")
+        .add("b", "string");
+    final Dataset<Row> df = spark.range(3)
+        .map(new MapFunction<Long, Row>() {
+               @Override
+               public Row call(Long i) {
+                 return create(i.intValue(), "s" + i);
+               }
+             },
+            Encoders.row(schema))
+        .filter(col("a").geq(1));
+    final List<Row> expected = Arrays.asList(create(1, "s1"), create(2, "s2"));
+    Assert.assertEquals(expected, df.collectAsList());
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6cc2033ebbe..c2ccb680cbd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -67,6 +67,8 @@ object MimaExcludes {
     // [SPARK-44507][SQL][CONNECT] Move AnalysisException to sql/api.
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException$"),
+    // [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RowFactory"),
     // [SPARK-44535][CONNECT][SQL] Move required Streaming API to sql/api
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode")
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/api/src/main/java/org/apache/spark/sql/RowFactory.java
similarity index 100%
rename from sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java
rename to sql/api/src/main/java/org/apache/spark/sql/RowFactory.java
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index a4198044886..9b95f74db3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -178,6 +178,13 @@ object Encoders {
    */
   def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
 
+  /**
+   * Creates a [[Row]] encoder for schema `schema`.
+   *
+   * @since 3.5.0
+   */
+  def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema)
+
   /**
   * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
    * This encoder maps T into a single byte array (binary) field.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 48fd009d6e7..4f7cf8da787 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -42,6 +42,7 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.sql.*;
+import static org.apache.spark.sql.RowFactory.create;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
 import org.apache.spark.sql.test.TestSparkSession;
@@ -1956,6 +1957,24 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(beans, dataset.collectAsList());
   }
 
+  @Test
+  public void testRowEncoder() {
+    final StructType schema = new StructType()
+        .add("a", "int")
+        .add("b", "string");
+    final Dataset<Row> df = spark.range(3)
+        .map(new MapFunction<Long, Row>() {
+               @Override
+               public Row call(Long i) {
+                 return create(i.intValue(), "s" + i);
+               }
+             },
+            Encoders.row(schema))
+        .filter(col("a").geq(1));
+    final List<Row> expected = Arrays.asList(create(1, "s1"), create(2, "s2"));
+    Assert.assertEquals(expected, df.collectAsList());
+  }
+
   public static class SpecificListsBean implements Serializable {
     private ArrayList<Integer> arrayList;
     private LinkedList<Integer> linkedList;

