This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6380859  [SPARK-36702][SQL][FOLLOWUP] ArrayUnion handle duplicated 
Double.NaN and Float.NaN
6380859 is described below

commit 638085953f931f98241856c9f652e5f15202fcc0
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Wed Sep 15 22:04:09 2021 +0800

    [SPARK-36702][SQL][FOLLOWUP] ArrayUnion handle duplicated Double.NaN and 
Float.NaN
    
    ### What changes were proposed in this pull request?
    According to 
https://github.com/apache/spark/pull/33955#discussion_r708570515 use normalized 
 NaN
    
    ### Why are the changes needed?
    Use normalized NaN for duplicated NaN value
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UT
    
    Closes #34003 from AngersZhuuuu/SPARK-36702-FOLLOWUP.
    
    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/collectionOperations.scala     | 13 ++++++++-----
 .../scala/org/apache/spark/sql/util/SQLOpenHashSet.scala    |  8 ++++++++
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e5620a1..47b2719 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3578,6 +3578,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new SQLOpenHashSet[Any]()
         val isNaN = SQLOpenHashSet.isNaN(elementType)
+        val valueNaN = SQLOpenHashSet.valueNaN(elementType)
         Seq(array1, array2).foreach { array =>
           var i = 0
           while (i < array.numElements()) {
@@ -3590,7 +3591,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
               val elem = array.get(i, elementType)
               if (isNaN(elem)) {
                 if (!hs.containsNaN) {
-                  arrayBuffer += elem
+                  arrayBuffer += valueNaN
                   hs.addNaN
                 }
               } else {
@@ -3688,16 +3689,18 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
 
         def withNaNCheck(body: String): String = {
           (elementType match {
-            case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)")
-            case FloatType => Some(s"java.lang.Float.isNaN((float)$value)")
+            case DoubleType =>
+              Some((s"java.lang.Double.isNaN((double)$value)", 
"java.lang.Double.NaN"))
+            case FloatType =>
+              Some((s"java.lang.Float.isNaN((float)$value)", 
"java.lang.Float.NaN"))
             case _ => None
-          }).map { isNaN =>
+          }).map { case (isNaN, valueNaN) =>
             s"""
                |if ($isNaN) {
                |  if (!$hashSet.containsNaN()) {
                |     $size++;
                |     $hashSet.addNaN();
-               |     $builder.$$plus$$eq($value);
+               |     $builder.$$plus$$eq($valueNaN);
                |  }
                |} else {
                |  $body
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
index 5ffe733..083cfdd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -69,4 +69,12 @@ object SQLOpenHashSet {
       case _ => (_: Any) => false
     }
   }
+
+  def valueNaN(dataType: DataType): Any = {
+    dataType match {
+      case DoubleType => java.lang.Double.NaN
+      case FloatType => java.lang.Float.NaN
+      case _ => null
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to