This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 816774aa532 [SPARK-42610][CONNECT] Add encoders to SQLImplicits 816774aa532 is described below commit 816774aa532ae4e017c937c86fdb784df200ee0e Author: Herman van Hovell <her...@databricks.com> AuthorDate: Mon Feb 27 21:32:49 2023 -0800 [SPARK-42610][CONNECT] Add encoders to SQLImplicits ### What changes were proposed in this pull request? Add implicit encoder resolution to `SQLImplicits` class. ### Why are the changes needed? API parity. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Added test to `SQLImplicitsTestSuite`. Closes #40205 from hvanhovell/SPARK-42610. Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> (cherry picked from commit 968f280fd0d488372b0b09738ff9728b45499bef) Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../scala/org/apache/spark/sql/SQLImplicits.scala | 240 ++++++++++++++++++++- .../scala/org/apache/spark/sql/SparkSession.scala | 2 +- .../apache/spark/sql/SQLImplicitsTestSuite.scala | 95 ++++++++ 3 files changed, 334 insertions(+), 3 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index e63c9481da5..8f429541def 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,13 +17,20 @@ package org.apache.spark.sql import scala.language.implicitConversions +import scala.reflect.classTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ /** - * A collection of implicit methods for converting names and Symbols into [[Column]]s. + * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for + * converting common Scala objects into [[Dataset]]s. * * @since 3.4.0 */ -abstract class SQLImplicits { +abstract class SQLImplicits extends LowPrioritySQLImplicits { /** * Converts $"col name" into a [[Column]]. @@ -41,4 +48,233 @@ abstract class SQLImplicits { * @since 3.4.0 */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** @since 3.4.0 */ + implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder + + /** @since 3.4.0 */ + implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder + + /** @since 3.4.0 */ + implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder + + /** @since 3.4.0 */ + implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder + + /** @since 3.4.0 */ + implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder + + /** @since 3.4.0 */ + implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder + + /** @since 3.4.0 */ + implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder + + /** @since 3.4.0 */ + implicit val newStringEncoder: Encoder[String] = StringEncoder + + /** @since 3.4.0 */ + implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = + AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER + + /** @since 3.4.0 */ + implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = + AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER + + /** @since 3.4.0 */ + implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER + + /** @since 3.4.0 */ + implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] = + AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER + + /** @since 3.4.0 */ + implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = + AgnosticEncoders.LocalDateTimeEncoder + + /** @since 3.4.0 */ + implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] = + AgnosticEncoders.STRICT_TIMESTAMP_ENCODER + + /** @since 3.4.0 */ + implicit val newInstantEncoder: Encoder[java.time.Instant] = + AgnosticEncoders.STRICT_INSTANT_ENCODER + + /** @since 3.4.0 */ + implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder + + /** @since 3.4.0 */ + implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder + + /** @since 3.4.0 */ + implicit def newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = { + ScalaReflection.encoderFor[A] + } + + // Boxed primitives + + /** @since 3.4.0 */ + implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder + + /** @since 3.4.0 */ + implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder + + /** @since 3.4.0 */ + implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder + + /** @since 3.4.0 */ + implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder + + /** @since 3.4.0 */ + implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder + + /** @since 3.4.0 */ + implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder + + /** @since 3.4.0 */ + implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder + + // Seqs + private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = { + IterableEncoder( + classTag[Seq[E]], + elementEncoder, + elementEncoder.nullable, + elementEncoder.lenientSerialization) + } + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) + + /** + * @since 3.4.0 + * @deprecated + * use [[newSequenceEncoder]] + */ + def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = + newSeqEncoder(ScalaReflection.encoderFor[A]) + + /** @since 3.4.0 */ + implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] = + ScalaReflection.encoderFor[T] + + // Maps + /** @since 3.4.0 */ + implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] + + /** + * Notice that we serialize `Set` to Catalyst array. The set property is only kept when + * manipulating the domain objects. The serialization format doesn't keep the set property. When + * we have a Catalyst array which contains duplicated elements and convert it to + * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. + * + * @since 3.4.0 + */ + implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] + + // Arrays + private def newArrayEncoder[E]( + elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = { + ArrayEncoder(elementEncoder, elementEncoder.nullable) + } + + /** @since 3.4.0 */ + implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder) + + /** @since 3.4.0 */ + implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder) + + /** @since 3.4.0 */ + implicit val newDoubleArrayEncoder: Encoder[Array[Double]] = + newArrayEncoder(PrimitiveDoubleEncoder) + + /** @since 3.4.0 */ + implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder( + PrimitiveFloatEncoder) + + /** @since 3.4.0 */ + implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder + + /** @since 3.4.0 */ + implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder( + PrimitiveShortEncoder) + + /** @since 3.4.0 */ + implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] = + newArrayEncoder(PrimitiveBooleanEncoder) + + /** @since 3.4.0 */ + implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder) + + /** @since 3.4.0 */ + implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = { + newArrayEncoder(ScalaReflection.encoderFor[A]) + } +} + +/** + * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting + * implicits are placed here to disambiguate resolution. + * + * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which + * are both `Seq` and `Product` + */ +trait LowPrioritySQLImplicits { + + /** @since 3.4.0 */ + implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = + ScalaReflection.encoderFor[T] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3aed781855c..fa13af00f14 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -207,7 +207,7 @@ class SparkSession( // Disable style checker so "implicits" object can start with lowercase i /** * (Scala-specific) Implicit methods available in Scala for converting common names and - * [[Symbol]]s into [[Column]]s. + * [[Symbol]]s into [[Column]]s, and for converting common Scala objects into `DataFrame`s. * * {{{ * val sparkSession = SparkSession.builder.getOrCreate() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 1f141d7c71a..3fcc135a22e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -16,15 +16,21 @@ */ package org.apache.spark.sql +import java.sql.{Date, Timestamp} +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.util.concurrent.atomic.AtomicLong import io.grpc.inprocess.InProcessChannelBuilder import org.scalatest.BeforeAndAfterAll import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder} import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.util.ConnectFunSuite +/** + * Test suite for SQL implicits. + */ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { private var session: SparkSession = _ @@ -44,4 +50,93 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { assertEqual($"x", Column("x")) assertEqual('y, Column("y")) } + + test("test implicit encoder resolution") { + val spark = session + import spark.implicits._ + def testImplicit[T: Encoder](expected: T): Unit = { + val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]] + val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind() + val serializer = expressionEncoder.createSerializer() + val deserializer = expressionEncoder.createDeserializer() + val actual = deserializer(serializer(expected)) + assert(actual === expected) + } + + val booleans = Array(false, true, false, false) + testImplicit(booleans.head) + testImplicit(java.lang.Boolean.valueOf(booleans.head)) + testImplicit(booleans) + testImplicit(booleans.toSeq) + testImplicit(booleans.toSeq)(newBooleanSeqEncoder) + + val bytes = Array(76.toByte, 59.toByte, 121.toByte) + testImplicit(bytes.head) + testImplicit(java.lang.Byte.valueOf(bytes.head)) + testImplicit(bytes) + testImplicit(bytes.toSeq) + testImplicit(bytes.toSeq)(newByteSeqEncoder) + + val shorts = Array(21.toShort, (-213).toShort, 14876.toShort) + testImplicit(shorts.head) + testImplicit(java.lang.Short.valueOf(shorts.head)) + testImplicit(shorts) + testImplicit(shorts.toSeq) + testImplicit(shorts.toSeq)(newShortSeqEncoder) + + val ints = Array(4, 6, 5) + testImplicit(ints.head) + testImplicit(java.lang.Integer.valueOf(ints.head)) + testImplicit(ints) + testImplicit(ints.toSeq) + testImplicit(ints.toSeq)(newIntSeqEncoder) + + val longs = Array(System.nanoTime(), System.currentTimeMillis()) + testImplicit(longs.head) + testImplicit(java.lang.Long.valueOf(longs.head)) + testImplicit(longs) + testImplicit(longs.toSeq) + testImplicit(longs.toSeq)(newLongSeqEncoder) + + val floats = Array(3f, 10.9f) + testImplicit(floats.head) + testImplicit(java.lang.Float.valueOf(floats.head)) + testImplicit(floats) + testImplicit(floats.toSeq) + testImplicit(floats.toSeq)(newFloatSeqEncoder) + + val doubles = Array(23.78d, -329.6d) + testImplicit(doubles.head) + testImplicit(java.lang.Double.valueOf(doubles.head)) + testImplicit(doubles) + testImplicit(doubles.toSeq) + testImplicit(doubles.toSeq)(newDoubleSeqEncoder) + + val strings = Array("foo", "baz", "bar") + testImplicit(strings.head) + testImplicit(strings) + testImplicit(strings.toSeq) + testImplicit(strings.toSeq)(newStringSeqEncoder) + + val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0)) + testImplicit(myTypes.head) + testImplicit(myTypes) + testImplicit(myTypes.toSeq) + testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType]) + + // Others. + val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18) + testImplicit(decimal) + testImplicit(BigDecimal(decimal)) + testImplicit(Date.valueOf(LocalDate.now())) + testImplicit(LocalDate.now()) + testImplicit(LocalDateTime.now()) + testImplicit(Instant.now()) + testImplicit(Timestamp.from(Instant.now())) + testImplicit(Period.ofYears(2)) + testImplicit(Duration.ofMinutes(77)) + testImplicit(SaveMode.Append) + testImplicit(Map(("key", "value"), ("foo", "baz"))) + testImplicit(Set(1, 2, 4)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org