This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 0843b7741fa [SPARK-44311][CONNECT][SQL] Improved support for UDFs on value classes
0843b7741fa is described below

commit 0843b7741fa959173fcc66067eedda9be501192c
Author: Emil Ejbyfeldt <eejbyfe...@liveintent.com>
AuthorDate: Tue Aug 1 10:50:04 2023 -0400

    [SPARK-44311][CONNECT][SQL] Improved support for UDFs on value classes
    
    ### What changes were proposed in this pull request?
    
    This PR fixes using UDFs on value classes when they are serialized as their underlying type. Previously this would only work if one defined a UDF taking the underlying type, or in cases where the derived schema does not "unbox" the value to its underlying type.
    
    Before this change the following code:
    ```
    final case class ValueClass(a: Int) extends AnyVal
    final case class Wrapper(v: ValueClass)
    
    val f = udf((a: ValueClass) => a.a > 0)
    
    spark.createDataset(Seq(Wrapper(ValueClass(1)))).filter(f(col("v"))).show()
    ```
    would fail with
    ```
    java.lang.ClassCastException: class org.apache.spark.sql.types.IntegerType$ cannot be cast to class org.apache.spark.sql.types.StructType (org.apache.spark.sql.types.IntegerType$ and org.apache.spark.sql.types.StructType are in unnamed module of loader 'app')
      at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.$anonfun$applyOrElse$220(Analyzer.scala:3241)
      at scala.Option.map(Option.scala:242)
      at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.$anonfun$applyOrElse$219(Analyzer.scala:3239)
      at scala.collection.immutable.List.map(List.scala:246)
      at scala.collection.immutable.List.map(List.scala:79)
      at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.applyOrElse(Analyzer.scala:3237)
      at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.applyOrElse(Analyzer.scala:3234)
      at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUpWithPruning$2(TreeNode.scala:566)
      at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
      at org.apache.spark.sql.catalyst.trees.TreeNode.transformUpWithPruning(TreeNode.scala:566)
    ```
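    
    For context, the only way to make this work before the fix was to define the UDF on the underlying type instead. A minimal sketch of that workaround (`g` is a hypothetical name; this path keeps working after the change, as the backwards-compatibility test below exercises):
    ```
    // Workaround before this fix: take the underlying Int, not the value class
    val g = udf((a: Int) => a > 0)
    
    spark.createDataset(Seq(Wrapper(ValueClass(1)))).filter(g(col("v"))).show()
    ```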
    
    ### Why are the changes needed?
    This is something that, as a user, I would expect to just work.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it fixes using a UDF on a value class that is serialized as its underlying type.
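    
    With this change the example above runs successfully. A sketch of the expected result (output formatting assumed):
    ```
    spark.createDataset(Seq(Wrapper(ValueClass(1)))).filter(f(col("v"))).show()
    // +---+
    // |  v|
    // +---+
    // |  1|
    // +---+
    ```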
    
    ### How was this patch tested?
    Existing tests and new test cases in DatasetSuite.scala.
    
    Closes #41876 from eejbyfeldt/SPARK-44311.
    
    Authored-by: Emil Ejbyfeldt <eejbyfe...@liveintent.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 821026bc730ce87e6e97d304c7673bfcb23fd03a)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  7 ++++++-
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  4 +++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 24 ++++++++++++++++++++++
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 30c6e4b4bc0..7f2471c9e19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3245,7 +3245,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             val dataType = udf.children(i).dataType
             encOpt.map { enc =>
               val attrs = if (enc.isSerializedAsStructForTopLevel) {
-                DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType])
+                // Value class that has been replaced with its underlying type
+                if (enc.schema.fields.size == 1 && enc.schema.fields.head.dataType == dataType) {
+                  DataTypeUtils.toAttributes(enc.schema.asInstanceOf[StructType])
+                } else {
+                  DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType])
+                }
               } else {
                 // the field name doesn't matter here, so we use
                 // a simple literal to avoid any overhead
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 40274a83340..910960bf84b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -162,7 +162,9 @@ case class ScalaUDF(
     if (useEncoder) {
       val enc = inputEncoders(i).get
       val fromRow = enc.createDeserializer()
-      val converter = if (enc.isSerializedAsStructForTopLevel) {
+      val unwrappedValueClass = enc.isSerializedAsStruct &&
+        enc.schema.fields.size == 1 && enc.schema.fields.head.dataType == dataType
+      val converter = if (enc.isSerializedAsStructForTopLevel && !unwrappedValueClass) {
         row: Any => fromRow(row.asInstanceOf[InternalRow])
       } else {
         val inputRow = new GenericInternalRow(1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a021b049cf0..c967540541a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2514,6 +2514,27 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-44311: UDF on value class taking underlying type (backwards 
compatability)") {
+    val f = udf((v: Int) => v > 1)
+    val ds = Seq(ValueClassContainer(ValueClass(1)), 
ValueClassContainer(ValueClass(2))).toDS()
+
+    checkDataset(ds.filter(f(col("v"))), ValueClassContainer(ValueClass(2)))
+  }
+
+  test("SPARK-44311: UDF on value class field in product") {
+    val f = udf((v: ValueClass) => v.i > 1)
+    val ds = Seq(ValueClassContainer(ValueClass(1)), ValueClassContainer(ValueClass(2))).toDS()
+
+    checkDataset(ds.filter(f(col("v"))), ValueClassContainer(ValueClass(2)))
+  }
+
+  test("SPARK-44311: UDF on value class this is stored as a struct") {
+    val f = udf((v: ValueClass) => v.i > 1)
+    val ds = Seq(Tuple1(ValueClass(1)), Tuple1(ValueClass(2))).toDS()
+
+    checkDataset(ds.filter(f(col("_1"))), Tuple1(ValueClass(2)))
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
@@ -2545,6 +2566,9 @@ class DatasetLargeResultCollectingSuite extends QueryTest
   }
 }
 
+case class ValueClass(i: Int) extends AnyVal
+case class ValueClassContainer(v: ValueClass)
+
 case class Bar(a: Int)
 
 object AssertExecutionId {
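
For illustration, the detection that both hunks above rely on can be written as a standalone predicate. A minimal sketch under the patch's assumptions (`isUnwrappedValueClass` is a hypothetical helper name, not part of the patch; the encoder API calls are those used in the diff):

```
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.DataType

// An input encoder refers to an "unwrapped" value class when it serializes
// as a struct whose single field has exactly the dataType of the actual
// input column, i.e. the value class was flattened to its underlying type.
def isUnwrappedValueClass(enc: ExpressionEncoder[_], dataType: DataType): Boolean =
  enc.isSerializedAsStruct &&
    enc.schema.fields.size == 1 && enc.schema.fields.head.dataType == dataType
```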


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
