This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 3f062ded22eb [SPARK-52601][SQL][4.0] Support primitive types in 
TransformingEncoder
3f062ded22eb is described below

commit 3f062ded22eb529e74f6b77379c7dcd6b2bf310c
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Tue Sep 16 15:11:53 2025 -0700

    [SPARK-52601][SQL][4.0] Support primitive types in TransformingEncoder
    
    Backport of https://github.com/apache/spark/pull/51313 to 4.0 branch.
    
    ### What changes were proposed in this pull request?
    
    Support defining TransformingEncoder that has a primitive type as the input 
type.
    
    ### Why are the changes needed?
    
    This came up for me when using a Scala 3 opaque type around a Long as a 
timestamp but wanting to have the encoder encode it as a timestamp. Ideally 
Spark would have some way of encoding a microsecond timestamp without going 
through a java.sql.Timestamp or java.time.Instant. But this at least makes it 
possible to achieve something similar (but less efficient) by defining a 
TransformingEncoder that takes a Long and returns a java.sql.Timestamp.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it allows TransformingEncoder to be used in more cases.
    
    ### How was this patch tested?
    
    New and existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52354 from eejbyfeldt/SPARK-52601-4.0.
    
    Authored-by: Emil Ejbyfeldt <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../spark/sql/catalyst/DeserializerBuildHelper.scala   |  4 ++--
 .../spark/sql/catalyst/encoders/EncoderUtils.scala     |  6 ++++++
 .../sql/catalyst/encoders/ExpressionEncoderSuite.scala | 18 ++++++++++++++++++
 3 files changed, 26 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 492ea741236e..5d1bbb024074 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, 
UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, 
KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, 
BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, 
InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, 
JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, 
MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, 
PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, 
PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncode [...]
-import 
org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, 
isNativeEncoder}
+import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, 
externalDataTypeFor, isNativeEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, 
IsNull, Literal, MapKeys, MapValues, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, 
CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, 
NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, 
UnresolvedMapObjects, WrapOption}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, 
CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils}
@@ -459,7 +459,7 @@ object DeserializerBuildHelper {
       Invoke(
         Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
         "decode",
-        ObjectType(tag.runtimeClass),
+        dataTypeForClass(tag.runtimeClass),
         createDeserializer(encoder, path, walkedTypePath) :: Nil)
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 8f717795605f..16d5adb064da 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -152,6 +152,12 @@ object EncoderUtils {
     VariantType -> classOf[VariantVal]
   )
 
+  def dataTypeForClass(c: Class[_]): DataType =
+    javaClassToPrimitiveType.get(c).getOrElse(ObjectType(c))
+
+  private val javaClassToPrimitiveType: Map[Class[_], DataType] =
+    typeJavaMapping.iterator.filter(_._2.isPrimitive).map(_.swap).toMap
+
   val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
     BooleanType -> classOf[java.lang.Boolean],
     ByteType -> classOf[java.lang.Byte],
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 616c6d65636d..1b5f1b109c45 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -749,6 +749,24 @@ class ExpressionEncoderSuite extends 
CodegenInterpretedPlanTest with AnalysisTes
     testDataTransformingEnc(enc, data)
   }
 
+  test("SPARK-52601 TransformingEncoder from primitive to timestamp") {
+    val enc: AgnosticEncoder[Long] =
+      TransformingEncoder[Long, java.sql.Timestamp](
+        classTag,
+        TimestampEncoder(true),
+        () =>
+          new Codec[Long, java.sql.Timestamp] with Serializable {
+            override def encode(in: Long): Timestamp = 
Timestamp.from(microsToInstant(in))
+            override def decode(out: Timestamp): Long = 
instantToMicros(out.toInstant)
+        }
+    )
+    val data: Seq[Long] = Seq(0L, 1L, 2L)
+
+    assert(enc.dataType === TimestampType)
+
+    testDataTransformingEnc(enc, data)
+  }
+
   val longEncForTimestamp: AgnosticEncoder[V[Long]] =
     TransformingEncoder[V[Long], java.sql.Timestamp](
       classTag,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to