This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new bde7aa61ce3 [SPARK-44613][CONNECT] Add Encoders object
bde7aa61ce3 is described below

commit bde7aa61ce3de15323a8920e8114a681fcd17000
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 1 14:39:38 2023 -0400

    [SPARK-44613][CONNECT] Add Encoders object
    
    ### What changes were proposed in this pull request?
    This PR adds the org.apache.spark.sql.Encoders object to Connect.
    
    ### Why are the changes needed?
    To increase compatibility with the SQL DataFrame API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds missing functionality.
    
    ### How was this patch tested?
    Added a couple of Java-based tests.
    
    Closes #42264 from hvanhovell/SPARK-44613.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 4f62f8a718e80dca13a1d44b6fdf8857f037c15e)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Encoders.scala | 262 +++++++++++++++++++++
 .../spark/sql/connect/client/SparkResult.scala     |  14 +-
 .../org/apache/spark/sql/JavaEncoderSuite.java     |  94 ++++++++
 .../CheckConnectJvmClientCompatibility.scala       |   8 +-
 .../connect/client/util/RemoteSparkSession.scala   |   2 +-
 5 files changed, 371 insertions(+), 9 deletions(-)

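For readers skimming the patch, here is a minimal usage sketch of the new
Encoders object from the Connect Scala client. The connection string and
sample data are illustrative assumptions, not part of this commit:

    import org.apache.spark.sql.{Encoders, SparkSession}

    // Connect to a running Spark Connect server (address is hypothetical).
    val spark = SparkSession.builder().remote("sc://localhost").create()

    // Primitive and boxed encoders mirror the SQL DataFrame API.
    val ints = spark.createDataset(Seq(1, 2, 3))(Encoders.scalaInt)

    // Tuple encoders compose existing encoders.
    val pairs = spark.createDataset(Seq(("a", 1L)))(
      Encoders.tuple(Encoders.STRING, Encoders.scalaLong))
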
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
new file mode 100644
index 00000000000..3f2f7ec96d4
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+
+/**
+ * Methods for creating an [[Encoder]].
+ *
+ * @since 3.5.0
+ */
+object Encoders {
+
+  /**
+   * An encoder for nullable boolean type. The Scala primitive encoder is available as
+   * [[scalaBoolean]].
+   * @since 3.5.0
+   */
+  def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
+
+  /**
+   * An encoder for nullable byte type. The Scala primitive encoder is available as [[scalaByte]].
+   * @since 3.5.0
+   */
+  def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder
+
+  /**
+   * An encoder for nullable short type. The Scala primitive encoder is available as
+   * [[scalaShort]].
+   * @since 3.5.0
+   */
+  def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder
+
+  /**
+   * An encoder for nullable int type. The Scala primitive encoder is available as [[scalaInt]].
+   * @since 3.5.0
+   */
+  def INT: Encoder[java.lang.Integer] = BoxedIntEncoder
+
+  /**
+   * An encoder for nullable long type. The Scala primitive encoder is available as [[scalaLong]].
+   * @since 3.5.0
+   */
+  def LONG: Encoder[java.lang.Long] = BoxedLongEncoder
+
+  /**
+   * An encoder for nullable float type. The Scala primitive encoder is available as
+   * [[scalaFloat]].
+   * @since 3.5.0
+   */
+  def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder
+
+  /**
+   * An encoder for nullable double type. The Scala primitive encoder is available as
+   * [[scalaDouble]].
+   * @since 3.5.0
+   */
+  def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
+
+  /**
+   * An encoder for nullable string type.
+   *
+   * @since 3.5.0
+   */
+  def STRING: Encoder[java.lang.String] = StringEncoder
+
+  /**
+   * An encoder for nullable decimal type.
+   *
+   * @since 3.5.0
+   */
+  def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER
+
+  /**
+   * An encoder for nullable date type.
+   *
+   * @since 3.5.0
+   */
+  def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false)
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.LocalDate` class to the
+   * internal representation of nullable Catalyst's DateType.
+   *
+   * @since 3.5.0
+   */
+  def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class to the
+   * internal representation of nullable Catalyst's TimestampNTZType.
+   *
+   * @since 3.5.0
+   */
+  def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder
+
+  /**
+   * An encoder for nullable timestamp type.
+   *
+   * @since 3.5.0
+   */
+  def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Instant` class to the internal
+   * representation of nullable Catalyst's TimestampType.
+   *
+   * @since 3.5.0
+   */
+  def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
+
+  /**
+   * An encoder for arrays of bytes.
+   *
+   * @since 3.5.0
+   */
+  def BINARY: Encoder[Array[Byte]] = BinaryEncoder
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Duration` class to the
+   * internal representation of nullable Catalyst's DayTimeIntervalType.
+   *
+   * @since 3.5.0
+   */
+  def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Period` class to the internal
+   * representation of nullable Catalyst's YearMonthIntervalType.
+   *
+   * @since 3.5.0
+   */
+  def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder
+
+  /**
+   * Creates an encoder for Java Bean of type T.
+   *
+   * T must be publicly accessible.
+   *
+   * supported types for java bean field:
+   *   - primitive types: boolean, int, double, etc.
+   *   - boxed types: Boolean, Integer, Double, etc.
+   *   - String
+   *   - java.math.BigDecimal, java.math.BigInteger
+   *   - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
+   *   - collection types: array, java.util.List, and map
+   *   - nested java bean.
+   *
+   * @since 3.5.0
+   */
+  def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass)
+
+  private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
+    ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
+  }
+
+  /**
+   * An encoder for 2-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = tupleEncoder(e1, e2)
+
+  /**
+   * An encoder for 3-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3]): Encoder[(T1, T2, T3)] = tupleEncoder(e1, e2, e3)
+
+  /**
+   * An encoder for 4-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3, T4](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = tupleEncoder(e1, e2, e3, e4)
+
+  /**
+   * An encoder for 5-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3, T4, T5](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4],
+      e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = tupleEncoder(e1, e2, e3, e4, e5)
+
+  /**
+   * An encoder for Scala's product type (tuples, case classes, etc).
+   * @since 3.5.0
+   */
+  def product[T <: Product: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+  /**
+   * An encoder for Scala's primitive int type.
+   * @since 3.5.0
+   */
+  def scalaInt: Encoder[Int] = PrimitiveIntEncoder
+
+  /**
+   * An encoder for Scala's primitive long type.
+   * @since 3.5.0
+   */
+  def scalaLong: Encoder[Long] = PrimitiveLongEncoder
+
+  /**
+   * An encoder for Scala's primitive double type.
+   * @since 3.5.0
+   */
+  def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder
+
+  /**
+   * An encoder for Scala's primitive float type.
+   * @since 3.5.0
+   */
+  def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder
+
+  /**
+   * An encoder for Scala's primitive byte type.
+   * @since 3.5.0
+   */
+  def scalaByte: Encoder[Byte] = PrimitiveByteEncoder
+
+  /**
+   * An encoder for Scala's primitive short type.
+   * @since 3.5.0
+   */
+  def scalaShort: Encoder[Short] = PrimitiveShortEncoder
+
+  /**
+   * An encoder for Scala's primitive boolean type.
+   * @since 3.5.0
+   */
+  def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index e3055b2678f..93c32aa2954 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -182,11 +182,15 @@ private[sql] class SparkResult[T](
   def toArray: Array[T] = {
     val result = encoder.clsTag.newArray(length)
     val rows = iterator
-    var i = 0
-    while (rows.hasNext) {
-      result(i) = rows.next()
-      assert(i < numRecords)
-      i += 1
+    try {
+      var i = 0
+      while (rows.hasNext) {
+        result(i) = rows.next()
+        assert(i < numRecords)
+        i += 1
+      }
+    } finally {
+      rows.close()
     }
     result
   }
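
The try/finally added above guarantees the row iterator is closed even when
filling the array fails part-way, so the client does not leak the underlying
result stream. As a standalone sketch of the pattern (CloseableIterator here
is a local stand-in, not the client's actual iterator type):

    // Stand-in for an iterator that owns a releasable resource.
    trait CloseableIterator[E] extends Iterator[E] with AutoCloseable

    def drainTo[E](rows: CloseableIterator[E], result: Array[E]): Array[E] = {
      try {
        var i = 0
        while (rows.hasNext) {
          result(i) = rows.next()
          i += 1
        }
      } finally {
        rows.close() // runs on success and on failure alike
      }
      result
    }
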
diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
new file mode 100644
index 00000000000..c8210a7a485
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+
+import static org.apache.spark.sql.Encoders.*;
+import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.connect.client.SparkConnectClient;
+import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+
+/**
+ * Tests for the encoders class.
+ */
+public class JavaEncoderSuite {
+  private static SparkSession spark;
+
+  @BeforeClass
+  public static void setup() {
+    SparkConnectServerUtils.start();
+    spark = SparkSession
+        .builder()
+        .client(SparkConnectClient
+            .builder()
+            .port(SparkConnectServerUtils.port())
+            .build())
+        .create();
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    spark.stop();
+    spark = null;
+    SparkConnectServerUtils.stop();
+  }
+
+  private static BigDecimal bigDec(long unscaled, int scale) {
+    return BigDecimal.valueOf(unscaled, scale);
+  }
+
+
+  private <T> Dataset<T> dataset(Encoder<T> encoder, T... elements) {
+    return spark.createDataset(Arrays.asList(elements), encoder);
+  }
+
+  @Test
+  public void testSimpleEncoders() {
+    final Column v = col("value");
+    assertFalse(
+        dataset(BOOLEAN(), false, true, false).select(every(v)).as(BOOLEAN()).head());
+    assertEquals(
+        7L,
+        dataset(BYTE(), (byte) -120, (byte)127).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        (short) 16,
+        dataset(SHORT(), (short)16, (short)2334).select(min(v)).as(SHORT()).head().shortValue());
+    assertEquals(
+        10L,
+        dataset(INT(), 1, 2, 3, 4).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        96L,
+        dataset(LONG(), 77L, 19L).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        0.12f,
+        dataset(FLOAT(), 0.12f, 0.3f, 44f).select(min(v)).as(FLOAT()).head(),
+        0.0001f);
+    assertEquals(
+        789d,
+        dataset(DOUBLE(), 789d, 12.213d, 10.01d).select(max(v)).as(DOUBLE()).head(),
+        0.0001f);
+    assertEquals(
+        bigDec(1002, 2),
+        dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2))
+            .select(sum(v)).as(DECIMAL()).head().setScale(2));
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 08028f26eb4..6e577e0f212 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -192,7 +192,6 @@ object CheckConnectJvmClientCompatibility {
 
       // functions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.call_udf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
@@ -216,7 +215,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
@@ -418,7 +416,11 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.streaming.RemoteStreamingQuery"),
       ProblemFilters.exclude[MissingClassProblem](
-        "org.apache.spark.sql.streaming.RemoteStreamingQuery$"))
+        "org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
+
+      // Encoders are in the wrong JAR
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"))
 
     checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules)
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 88c0785d3af..f14109e49b5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon
 object SparkConnectServerUtils {
 
   // Server port
-  private[spark] val port: Int =
+  val port: Int =
     ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
 
   @volatile private var stopped = false

