This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new bf7654998fb [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders.scala bf7654998fb is described below commit bf7654998fbbec9d5bdee6f46462cffef495545f Author: Herman van Hovell <her...@databricks.com> AuthorDate: Mon Aug 7 15:09:58 2023 +0200 [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders.scala ### What changes were proposed in this pull request? ### Why are the changes needed? It is currently not possible to create a `RowEncoder` using the public API. The internal APIs for this will change in Spark 3.5, which means that library maintainers have to update their code if they use a RowEncoder. To avoid this happening again, we add this method to the public API. ### Does this PR introduce _any_ user-facing change? Yes. It adds the `row` method to `Encoders`. ### How was this patch tested? Added tests to connect and sql. Closes #42366 from hvanhovell/SPARK-44686. Lead-authored-by: Herman van Hovell <her...@databricks.com> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Encoders.scala | 10 ++++++- .../org/apache/spark/sql/JavaEncoderSuite.java | 31 +++++++++++++++++++--- project/MimaExcludes.scala | 2 ++ .../main/java/org/apache/spark/sql/RowFactory.java | 0 .../main/scala/org/apache/spark/sql/Encoders.scala | 7 +++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 19 +++++++++++++ 6 files changed, 64 insertions(+), 5 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala index 3f2f7ec96d4..74f01338031 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -19,8 +19,9 @@ package 
org.apache.spark.sql import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.types.StructType /** * Methods for creating an [[Encoder]]. @@ -168,6 +169,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema) + private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]] } diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java index c8210a7a485..6e5fb72d496 100644 --- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -16,21 +16,26 @@ */ package org.apache.spark.sql; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + import org.junit.*; import static org.junit.Assert.*; import static org.apache.spark.sql.Encoders.*; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.connect.client.SparkConnectClient; import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils; - -import java.math.BigDecimal; -import java.util.Arrays; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.types.StructType; /** * 
Tests for the encoders class. */ -public class JavaEncoderSuite { +public class JavaEncoderSuite implements Serializable { private static SparkSession spark; @BeforeClass @@ -91,4 +96,22 @@ public class JavaEncoderSuite { dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2)) .select(sum(v)).as(DECIMAL()).head().setScale(2)); } + + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset<Row> df = spark.range(3) + .map(new MapFunction<Long, Row>() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List<Row> expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d0fc8f2b116..9e5eb66ce94 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,6 +71,8 @@ object MimaExcludes { // [SPARK-44507][SQL][CONNECT] Move AnalysisException to sql/api. 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.AnalysisException$"), + // [SPARK-44686][CONNECT][SQL] Add the ability to create a RowEncoder in Encoders + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RowFactory"), // [SPARK-44535][CONNECT][SQL] Move required Streaming API to sql/api ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupStateTimeout"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.OutputMode"), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/api/src/main/java/org/apache/spark/sql/RowFactory.java similarity index 100% rename from sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java rename to sql/api/src/main/java/org/apache/spark/sql/RowFactory.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index a4198044886..9b95f74db3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -178,6 +178,13 @@ object Encoders { */ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 48fd009d6e7..4f7cf8da787 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -42,6 +42,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; +import static org.apache.spark.sql.RowFactory.create; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.test.TestSparkSession; @@ -1956,6 +1957,24 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(beans, dataset.collectAsList()); } + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset<Row> df = spark.range(3) + .map(new MapFunction<Long, Row>() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List<Row> expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } + public static class SpecificListsBean implements Serializable { private ArrayList<Integer> arrayList; private LinkedList<Integer> linkedList; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org