This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 1d8f74a  [SPARK-34545][SQL] Fix issues with valueCompare feature of pyrolite
1d8f74a is described below

commit 1d8f74ac0053e9af70b289866dd6055547fce21e
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Sun Mar 7 19:12:42 2021 -0600

    [SPARK-34545][SQL] Fix issues with valueCompare feature of pyrolite
    
    ### What changes were proposed in this pull request?
    
    pyrolite 4.21 introduced and enabled value comparison by default (`valueCompare=true`) during object memoization and serialization: https://github.com/irmen/Pyrolite/blob/pyrolite-4.21/java/src/main/java/net/razorvine/pickle/Pickler.java#L112-L122
    This change has an undesired effect when we serialize a row (actually a `GenericRowWithSchema`) to be passed to Python: https://github.com/apache/spark/blob/branch-3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala#L60. A simple example is that
    ```
    new GenericRowWithSchema(Array[Any](1.0, 1.0), StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
    ```
    and
    ```
    new GenericRowWithSchema(Array[Any](1, 1), StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))
    ```
    are currently considered equal, so the second instance is replaced with a short memo reference to the first one during serialization.
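
    The equality itself is easy to check with a minimal sketch (not part of this patch; it only assumes Spark's catalyst classes are on the classpath):
    ```
    import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
    import org.apache.spark.sql.types._

    // Row.equals() compares values element-wise and ignores the schema, and Scala's
    // cooperative numeric equality makes the boxed 1.0 equal to the boxed 1.
    val doubleRow = new GenericRowWithSchema(Array[Any](1.0, 1.0),
      StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))
    val intRow = new GenericRowWithSchema(Array[Any](1, 1),
      StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))

    assert(doubleRow == intRow)  // true, which is what confuses value-based memoization
    ```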
    
    ### Why are the changes needed?
    The above can cause nasty issues like the one described in https://issues.apache.org/jira/browse/SPARK-34545:
    
    ```
    >>> from pyspark.sql.functions import udf
    >>> from pyspark.sql.types import *
    >>>
    >>> def udf1(data_type):
            def u1(e):
                return e[0]
            return udf(u1, data_type)
    >>>
    >>> df = spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
    >>>
    >>> df = df.withColumn("c3", udf1(DoubleType())("c1"))
    >>> df = df.withColumn("c4", udf1(IntegerType())("c2"))
    >>>
    >>> df.select("c3").show()
    +---+
    | c3|
    +---+
    |1.0|
    +---+
    
    >>> df.select("c4").show()
    +---+
    | c4|
    +---+
    |  1|
    +---+
    
    >>> df.select("c3", "c4").show()
    +---+----+
    | c3|  c4|
    +---+----+
    |1.0|null|
    +---+----+
    ```
    This is because during serialization from the JVM to Python `GenericRowWithSchema(1.0, 1.0)` (`c1`) is memoized first, and when `GenericRowWithSchema(1, 1)` (`c2`) comes next, it is replaced with a short memo reference to `c1` (instead of being serialized itself) because the two rows are `equal()`. The Python function then runs, but the return type of `c4` is expected to be `IntegerType`, and if a different type (`DoubleType`) comes back from Python then it is discarded: https://github.com/apache/spark/blob/bra [...]
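
    A minimal sketch of the memoization behaviour itself (not part of this patch; it uses plain `java.util` lists rather than Spark rows and assumes pyrolite 4.21+ on the classpath):
    ```
    import scala.collection.JavaConverters._
    import net.razorvine.pickle.Pickler

    // Two distinct list objects that are equal() by value.
    val a = Seq[Any](1.0, 1.0).asJava
    val b = Seq[Any](1.0, 1.0).asJava

    // Default pyrolite 4.21+ behaviour: the memo is keyed by value, so `b` is expected to
    // collapse to a short memo reference back to `a`.
    val withValueCompare =
      new Pickler(/* useMemo = */ true, /* valueCompare = */ true).dumps(Seq(a, b).asJava)
    // Behaviour this PR switches to: the memo only matches the same object reference, so
    // `b` is serialized out in full.
    val withoutValueCompare =
      new Pickler(/* useMemo = */ true, /* valueCompare = */ false).dumps(Seq(a, b).asJava)

    // The second pickle is expected to be larger because nothing is collapsed into a memo reference.
    println(s"valueCompare=true: ${withValueCompare.length} bytes, " +
      s"valueCompare=false: ${withoutValueCompare.length} bytes")
    ```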
    
    After this PR:
    ```
    >>> df.select("c3", "c4").show()
    +---+---+
    | c3| c4|
    +---+---+
    |1.0|  1|
    +---+---+
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes a correctness issue.
    
    ### How was this patch tested?
    Added new UT + manual tests.
    
    Closes #31682 from peter-toth/SPARK-34545-fix-row-comparison.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../main/scala/org/apache/spark/api/python/SerDeUtil.scala  |  9 ++++++---
 .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  |  6 ++++--
 python/pyspark/sql/tests/test_udf.py                        | 11 +++++++++++
 .../spark/sql/execution/python/BatchEvalPythonExec.scala    | 13 ++++++++++++-
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index dc2587a..dd962ca 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -78,7 +78,8 @@ private[spark] object SerDeUtil extends Logging {
    * Choose batch size based on size of objects
    */
   private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
-    private val pickle = new Pickler()
+    private val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     private var batch = 1
     private val buffer = new mutable.ArrayBuffer[Any]
 
@@ -131,7 +132,8 @@ private[spark] object SerDeUtil extends Logging {
   }
 
   private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
-    val pickle = new Pickler
+    val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
     val kt = Try {
       pickle.dumps(t._1)
     }
@@ -182,7 +184,8 @@ private[spark] object SerDeUtil extends Logging {
       if (batchSize == 0) {
         new AutoBatchedPickler(cleaned)
       } else {
-        val pickle = new Pickler
+        val pickle = new Pickler(/* useMemo = */ true,
+          /* valueCompare = */ false)
         cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava))
       }
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 68f6ed4..92ae6b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1313,8 +1313,10 @@ private[spark] abstract class SerDeBase {
   def dumps(obj: AnyRef): Array[Byte] = {
     obj match {
       // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
-      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
-      case _ => new Pickler().dumps(obj)
+      case array: Array[_] => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(array.toSeq.asJava)
+      case _ => new Pickler(/* useMemo = */ true,
+        /* valueCompare = */ false).dumps(obj)
     }
   }
 
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index bfc55df..0d13361 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -674,6 +674,17 @@ class UDFTests(ReusedSQLTestCase):
         self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
                          .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')
 
+    # SPARK-34545
+    def test_udf_input_serialization_valuecompare_disabled(self):
+        def f(e):
+            return e[0]
+
+        df = self.spark.createDataFrame([((1.0, 1.0), (1, 1))], ['c1', 'c2'])
+        result = df.select("*", udf(f, DoubleType())("c1").alias('c3'),
+                           udf(f, IntegerType())("c2").alias('c4'))
+        self.assertEqual(result.collect(),
+                         [Row(c1=Row(_1=1.0, _2=1.0), c2=Row(_1=1, _2=1), c3=1.0, c4=1)])
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index b6d8e59..2ab7262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,7 +46,18 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
     val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
 
     // enable memo iff we serialize the row with schema (schema and class should be memorized)
-    val pickle = new Pickler(needConversion)
+    // pyrolite 4.21+ can lookup objects in its cache by value, but `GenericRowWithSchema` objects,
+    // that we pass from JVM to Python, don't define their `equals()` to take the type of the
+    // values or the schema of the row into account. This causes rows like
+    // `GenericRowWithSchema(Array(1.0, 1.0),
+    //    StructType(Seq(StructField("_1", DoubleType), StructField("_2", DoubleType))))`
+    // and
+    // `GenericRowWithSchema(Array(1, 1),
+    //    StructType(Seq(StructField("_1", IntegerType), StructField("_2", IntegerType))))`
+    // to be `equal()` and so we need to disable this feature explicitly (`valueCompare=false`).
+    // Please note that cache by reference is still enabled depending on `needConversion`.
+    val pickle = new Pickler(/* useMemo = */ needConversion,
+      /* valueCompare = */ false)
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
     val inputIterator = iter.map { row =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
