This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 816774aa532 [SPARK-42610][CONNECT] Add encoders to SQLImplicits
816774aa532 is described below

commit 816774aa532ae4e017c937c86fdb784df200ee0e
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Feb 27 21:32:49 2023 -0800

    [SPARK-42610][CONNECT] Add encoders to SQLImplicits
    
    ### What changes were proposed in this pull request?
    Add implicit encoder resolution to the `SQLImplicits` class of the Spark Connect Scala client.
    
    ### Why are the changes needed?
    API parity: the Spark Connect Scala client should resolve the same implicit encoders as the existing `SQLImplicits` in the Spark SQL API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. With `import spark.implicits._`, encoders for primitives, boxed primitives, strings, decimals, date/time types, Java enums, `Seq`s, `Map`s, `Set`s, arrays, and `Product` types are now resolved implicitly.
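
    A minimal usage sketch of the new resolution (the `Person` case class, the wrapper object, and the way the session is obtained are illustrative assumptions, not part of this patch):

    ```scala
    import org.apache.spark.sql.{Encoder, SparkSession}

    // Hypothetical top-level case class, purely for illustration.
    case class Person(name: String, age: Int)

    object ImplicitEncoderExample {
      // Assumes a Spark Connect SparkSession has been created elsewhere.
      def demo(spark: SparkSession): Unit = {
        import spark.implicits._

        // With this patch, these encoders are resolved implicitly:
        val intEncoder: Encoder[Int] = implicitly[Encoder[Int]]
        val seqEncoder: Encoder[Seq[String]] = implicitly[Encoder[Seq[String]]]
        val mapEncoder: Encoder[Map[String, Int]] = implicitly[Encoder[Map[String, Int]]]
        val personEncoder: Encoder[Person] = implicitly[Encoder[Person]] // via newProductEncoder
      }
    }
    ```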
    
    ### How was this patch tested?
    Added test to `SQLImplicitsTestSuite`.
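
    A condensed sketch of the round-trip check the suite performs (the helper name is illustrative; the full test in the diff below also covers arrays, boxed types, and date/time types):

    ```scala
    import org.apache.spark.sql.Encoder
    import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}

    // Serializes a value with the implicitly resolved encoder and asserts that
    // deserializing the result yields the original value.
    def roundTrip[T: Encoder](expected: T): Unit = {
      val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]]
      val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind()
      val serializer = expressionEncoder.createSerializer()
      val deserializer = expressionEncoder.createDeserializer()
      assert(deserializer(serializer(expected)) == expected)
    }
    ```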
    
    Closes #40205 from hvanhovell/SPARK-42610.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
    (cherry picked from commit 968f280fd0d488372b0b09738ff9728b45499bef)
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../scala/org/apache/spark/sql/SQLImplicits.scala  | 240 ++++++++++++++++++++-
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../apache/spark/sql/SQLImplicitsTestSuite.scala   |  95 ++++++++
 3 files changed, 334 insertions(+), 3 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index e63c9481da5..8f429541def 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,13 +17,20 @@
 package org.apache.spark.sql
 
 import scala.language.implicitConversions
+import scala.reflect.classTag
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 
 /**
- * A collection of implicit methods for converting names and Symbols into [[Column]]s.
+ * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for
+ * converting common Scala objects into [[Dataset]]s.
  *
  * @since 3.4.0
  */
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {
 
   /**
    * Converts $"col name" into a [[Column]].
@@ -41,4 +48,233 @@ abstract class SQLImplicits {
    * @since 3.4.0
    */
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+  /** @since 3.4.0 */
+  implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder
+
+  /** @since 3.4.0 */
+  implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder
+
+  /** @since 3.4.0 */
+  implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder
+
+  /** @since 3.4.0 */
+  implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder
+
+  /** @since 3.4.0 */
+  implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder
+
+  /** @since 3.4.0 */
+  implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder
+
+  /** @since 3.4.0 */
+  implicit val newStringEncoder: Encoder[String] = StringEncoder
+
+  /** @since 3.4.0 */
+  implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] =
+    AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
+    AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] =
+    AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] =
+    AgnosticEncoders.LocalDateTimeEncoder
+
+  /** @since 3.4.0 */
+  implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] =
+    AgnosticEncoders.STRICT_TIMESTAMP_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newInstantEncoder: Encoder[java.time.Instant] =
+    AgnosticEncoders.STRICT_INSTANT_ENCODER
+
+  /** @since 3.4.0 */
+  implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder
+
+  /** @since 3.4.0 */
+  implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder
+
+  /** @since 3.4.0 */
+  implicit def newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = {
+    ScalaReflection.encoderFor[A]
+  }
+
+  // Boxed primitives
+
+  /** @since 3.4.0 */
+  implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder
+
+  /** @since 3.4.0 */
+  implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
+
+  // Seqs
+  private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = {
+    IterableEncoder(
+      classTag[Seq[E]],
+      elementEncoder,
+      elementEncoder.nullable,
+      elementEncoder.lenientSerialization)
+  }
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder)
+
+  /**
+   * @since 3.4.0
+   * @deprecated
+   *   use [[newSequenceEncoder]]
+   */
+  def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] =
+    newSeqEncoder(ScalaReflection.encoderFor[A])
+
+  /** @since 3.4.0 */
+  implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] =
+    ScalaReflection.encoderFor[T]
+
+  // Maps
+  /** @since 3.4.0 */
+  implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+  /**
+   * Notice that we serialize `Set` to Catalyst array. The set property is only kept when
+   * manipulating the domain objects. The serialization format doesn't keep the set property. When
+   * we have a Catalyst array which contains duplicated elements and convert it to
+   * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated.
+   *
+   * @since 3.4.0
+   */
+  implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+  // Arrays
+  private def newArrayEncoder[E](
+      elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = {
+    ArrayEncoder(elementEncoder, elementEncoder.nullable)
+  }
+
+  /** @since 3.4.0 */
+  implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newDoubleArrayEncoder: Encoder[Array[Double]] =
+    newArrayEncoder(PrimitiveDoubleEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder(
+    PrimitiveFloatEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder
+
+  /** @since 3.4.0 */
+  implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder(
+    PrimitiveShortEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] =
+    newArrayEncoder(PrimitiveBooleanEncoder)
+
+  /** @since 3.4.0 */
+  implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder)
+
+  /** @since 3.4.0 */
+  implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = {
+    newArrayEncoder(ScalaReflection.encoderFor[A])
+  }
+}
+
+/**
+ * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting
+ * implicits are placed here to disambiguate resolution.
+ *
+ * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which
+ * are both `Seq` and `Product`
+ */
+trait LowPrioritySQLImplicits {
+
+  /** @since 3.4.0 */
+  implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] =
+    ScalaReflection.encoderFor[T]
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3aed781855c..fa13af00f14 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -207,7 +207,7 @@ class SparkSession(
   // Disable style checker so "implicits" object can start with lowercase i
   /**
    * (Scala-specific) Implicit methods available in Scala for converting common names and
-   * [[Symbol]]s into [[Column]]s.
+   * [[Symbol]]s into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
    *
    * {{{
    *   val sparkSession = SparkSession.builder.getOrCreate()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index 1f141d7c71a..3fcc135a22e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -16,15 +16,21 @@
  */
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 import java.util.concurrent.atomic.AtomicLong
 
 import io.grpc.inprocess.InProcessChannelBuilder
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 
+/**
+ * Test suite for SQL implicits.
+ */
 class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
   private var session: SparkSession = _
 
@@ -44,4 +50,93 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
     assertEqual($"x", Column("x"))
     assertEqual('y, Column("y"))
   }
+
+  test("test implicit encoder resolution") {
+    val spark = session
+    import spark.implicits._
+    def testImplicit[T: Encoder](expected: T): Unit = {
+      val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]]
+      val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind()
+      val serializer = expressionEncoder.createSerializer()
+      val deserializer = expressionEncoder.createDeserializer()
+      val actual = deserializer(serializer(expected))
+      assert(actual === expected)
+    }
+
+    val booleans = Array(false, true, false, false)
+    testImplicit(booleans.head)
+    testImplicit(java.lang.Boolean.valueOf(booleans.head))
+    testImplicit(booleans)
+    testImplicit(booleans.toSeq)
+    testImplicit(booleans.toSeq)(newBooleanSeqEncoder)
+
+    val bytes = Array(76.toByte, 59.toByte, 121.toByte)
+    testImplicit(bytes.head)
+    testImplicit(java.lang.Byte.valueOf(bytes.head))
+    testImplicit(bytes)
+    testImplicit(bytes.toSeq)
+    testImplicit(bytes.toSeq)(newByteSeqEncoder)
+
+    val shorts = Array(21.toShort, (-213).toShort, 14876.toShort)
+    testImplicit(shorts.head)
+    testImplicit(java.lang.Short.valueOf(shorts.head))
+    testImplicit(shorts)
+    testImplicit(shorts.toSeq)
+    testImplicit(shorts.toSeq)(newShortSeqEncoder)
+
+    val ints = Array(4, 6, 5)
+    testImplicit(ints.head)
+    testImplicit(java.lang.Integer.valueOf(ints.head))
+    testImplicit(ints)
+    testImplicit(ints.toSeq)
+    testImplicit(ints.toSeq)(newIntSeqEncoder)
+
+    val longs = Array(System.nanoTime(), System.currentTimeMillis())
+    testImplicit(longs.head)
+    testImplicit(java.lang.Long.valueOf(longs.head))
+    testImplicit(longs)
+    testImplicit(longs.toSeq)
+    testImplicit(longs.toSeq)(newLongSeqEncoder)
+
+    val floats = Array(3f, 10.9f)
+    testImplicit(floats.head)
+    testImplicit(java.lang.Float.valueOf(floats.head))
+    testImplicit(floats)
+    testImplicit(floats.toSeq)
+    testImplicit(floats.toSeq)(newFloatSeqEncoder)
+
+    val doubles = Array(23.78d, -329.6d)
+    testImplicit(doubles.head)
+    testImplicit(java.lang.Double.valueOf(doubles.head))
+    testImplicit(doubles)
+    testImplicit(doubles.toSeq)
+    testImplicit(doubles.toSeq)(newDoubleSeqEncoder)
+
+    val strings = Array("foo", "baz", "bar")
+    testImplicit(strings.head)
+    testImplicit(strings)
+    testImplicit(strings.toSeq)
+    testImplicit(strings.toSeq)(newStringSeqEncoder)
+
+    val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0))
+    testImplicit(myTypes.head)
+    testImplicit(myTypes)
+    testImplicit(myTypes.toSeq)
+    testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType])
+
+    // Others.
+    val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18)
+    testImplicit(decimal)
+    testImplicit(BigDecimal(decimal))
+    testImplicit(Date.valueOf(LocalDate.now()))
+    testImplicit(LocalDate.now())
+    testImplicit(LocalDateTime.now())
+    testImplicit(Instant.now())
+    testImplicit(Timestamp.from(Instant.now()))
+    testImplicit(Period.ofYears(2))
+    testImplicit(Duration.ofMinutes(77))
+    testImplicit(SaveMode.Append)
+    testImplicit(Map(("key", "value"), ("foo", "baz")))
+    testImplicit(Set(1, 2, 4))
+  }
 }

