This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new b96cbfb  [SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float.NaN
b96cbfb is described below

commit b96cbfb27d33f28dc5526dcf626642244125957d
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Wed Sep 22 23:51:41 2021 +0800

    [SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float.NaN
    
    ### What changes were proposed in this pull request?
    For query
    ```
    select array_except(array(cast('nan' as double), 1d), array(cast('nan' as double)))
    ```
    This returns [NaN, 1d], but it should return [1d].
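    A minimal reproduction (hypothetical spark-shell session; the "after" result reflects this fix):
    ```
    // array_except should treat a NaN in both arrays as a match.
    val df = spark.sql(
      "select array_except(array(cast('nan' as double), 1d), array(cast('nan' as double)))")
    df.show(false)  // before this patch: [NaN, 1.0]; after: [1.0]
    ```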
    This issue is caused by `OpenHashSet`, which cannot handle `Double.NaN` and `Float.NaN` either.
    This PR fixes it based on https://github.com/apache/spark/pull/33955.
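    The underlying pitfall is `NaN` equality; a quick Scala REPL check (illustration only, not part of this patch):
    ```
    Double.NaN == Double.NaN                // false: IEEE-754 comparison
    (Double.NaN: Any) == (Double.NaN: Any)  // also false: Scala compares boxed numerics numerically
    java.lang.Double.valueOf(Double.NaN)
      .equals(java.lang.Double.valueOf(Double.NaN))  // true: Java's boxed equals
    ```
    So a hash set keyed on these comparisons can never find a `NaN` it already holds, and duplicates are never collapsed.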
    
    ### Why are the changes needed?
    Fix a correctness bug: `array_except` failed to exclude `NaN` elements that appear in both arrays.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `ArrayExcept` now treats equal `NaN` values as the same element, so duplicated `NaN`s are removed as expected.
    
    ### How was this patch tested?
    Added a unit test.
    
    Closes #33994 from AngersZhuuuu/SPARK-36753.
    
    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit a7cbe699863a6b68d27bdf3934dda7d396d80404)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 61 +++++++++++++---------
 .../expressions/CollectionExpressionsSuite.scala   | 17 ++++++
 2 files changed, 53 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c153181..d89c77e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -37,7 +37,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
-import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
@@ -3839,32 +3838,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
   @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
     if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
-        val hs = new OpenHashSet[Any]
-        var notFoundNullElement = true
+        val hs = new SQLOpenHashSet[Any]
+        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+        val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+          (value: Any) => hs.add(value),
+          (valueNaN: Any) => {})
+        val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
+          (value: Any) =>
+            if (!hs.contains(value)) {
+              arrayBuffer += value
+              hs.add(value)
+            },
+          (valueNaN: Any) => arrayBuffer += valueNaN)
         var i = 0
         while (i < array2.numElements()) {
           if (array2.isNullAt(i)) {
-            notFoundNullElement = false
+            hs.addNull()
           } else {
             val elem = array2.get(i, elementType)
-            hs.add(elem)
+            withArray2NaNCheckFunc(elem)
           }
           i += 1
         }
-        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         i = 0
         while (i < array1.numElements()) {
           if (array1.isNullAt(i)) {
-            if (notFoundNullElement) {
+            if (!hs.containsNull()) {
               arrayBuffer += null
-              notFoundNullElement = false
+              hs.addNull()
             }
           } else {
             val elem = array1.get(i, elementType)
-            if (!hs.contains(elem)) {
-              arrayBuffer += elem
-              hs.add(elem)
-            }
+            withArray1NaNCheckFunc(elem)
           }
           i += 1
         }
@@ -3933,10 +3938,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
       val ptName = CodeGenerator.primitiveTypeName(jt)
 
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
-        val notFoundNullElement = ctx.freshName("notFoundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
         val builder = ctx.freshName("builder")
-        val openHashSet = classOf[OpenHashSet[_]].getName
+        val openHashSet = classOf[SQLOpenHashSet[_]].getName
         val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
         val hashSet = ctx.freshName("hashSet")
         val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3947,7 +3951,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
             if (left.dataType.asInstanceOf[ArrayType].containsNull) {
               s"""
                  |if ($array2.isNullAt($i)) {
-                 |  $notFoundNullElement = false;
+                 |  $hashSet.addNull();
                  |} else {
                  |  $body
                  |}
@@ -3965,18 +3969,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
           }
 
         val writeArray2ToHashSet = withArray2NullCheck(
-          s"""
-             |$jt $value = ${genGetValue(array2, i)};
-             |$hashSet.add$hsPostFix($hsValueCast$value);
-           """.stripMargin)
+          s"$jt $value = ${genGetValue(array2, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
+              s"$hashSet.add$hsPostFix($hsValueCast$value);",
+              (valueNaN: Any) => ""))
 
         def withArray1NullAssignment(body: String) =
           if (left.dataType.asInstanceOf[ArrayType].containsNull) {
             s"""
                |if ($array1.isNullAt($i)) {
-               |  if ($notFoundNullElement) {
+               |  if (!$hashSet.containsNull()) {
+               |    $hashSet.addNull();
                |    $nullElementIndex = $size;
-               |    $notFoundNullElement = false;
                |    $size++;
                |    $builder.$$plus$$eq($nullValueHolder);
                |  }
@@ -3988,9 +3992,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
             body
           }
 
-        val processArray1 = withArray1NullAssignment(
+        val body =
           s"""
-             |$jt $value = ${genGetValue(array1, i)};
              |if (!$hashSet.contains($hsValueCast$value)) {
              |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
              |    break;
@@ -3998,12 +4001,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
              |  $hashSet.add$hsPostFix($hsValueCast$value);
              |  $builder.$$plus$$eq($value);
              |}
-           """.stripMargin)
+           """.stripMargin
+
+        val processArray1 = withArray1NullAssignment(
+          s"$jt $value = ${genGetValue(array1, i)};" +
+            SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
+              (valueNaN: String) =>
+                s"""
+                   |$size++;
+                   |$builder.$$plus$$eq($valueNaN);
+                 """.stripMargin))
 
         // Only need to track null element index when array1's element is nullable.
         val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
           s"""
-             |boolean $notFoundNullElement = true;
              |int $nullElementIndex = -1;
            """.stripMargin
         } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 6a66e36..efffe95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1912,6 +1912,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(Float.NaN, null, 1f))
   }
 
+  test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and 
Float.Nan") {
+    checkEvaluation(ArrayExcept(
+      Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))),
+      Seq(1d))
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
+      Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType))),
+      Seq(1d))
+    checkEvaluation(ArrayExcept(
+      Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN))),
+      Seq(1f))
+    checkEvaluation(ArrayExcept(
+      Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
+      Literal.create(Seq(Float.NaN, null), ArrayType(FloatType))),
+      Seq(1f))
+  }
+
   test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and 
Float.Nan") {
     checkEvaluation(ArrayIntersect(
       Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
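
For readers unfamiliar with `SQLOpenHashSet` (introduced in https://github.com/apache/spark/pull/33955): the wrapper returned by `withNaNCheckFunc` routes `NaN` values around the hash set, since the set cannot find `NaN` under the equality rules shown earlier. Below is a minimal plain-Scala sketch of that contract; the class and signatures are hypothetical stand-ins for the real implementation:

```
// Sketch only: a set that tracks NaN and null membership out of band,
// mirroring SQLOpenHashSet's addNull/containsNull and NaN bookkeeping.
final class NaNAwareSet {
  private val set = scala.collection.mutable.HashSet.empty[Any]
  private var hasNaN = false
  private var hasNull = false
  def add(v: Any): Unit = set += v
  def contains(v: Any): Boolean = set(v)
  def addNaN(): Unit = hasNaN = true
  def containsNaN: Boolean = hasNaN
  def addNull(): Unit = hasNull = true
  def containsNull: Boolean = hasNull
}

// Build a per-element function that sends NaN to a dedicated handler
// at most once, and everything else to the normal handler.
def withNaNCheckFunc(
    isNaN: Any => Boolean,
    hs: NaNAwareSet,
    handleNotNaN: Any => Unit,
    handleNaN: Any => Unit): Any => Unit = { value =>
  if (isNaN(value)) {
    if (!hs.containsNaN) { hs.addNaN(); handleNaN(value) }
  } else {
    handleNotNaN(value)
  }
}
```

In the interpreted path above, array2's pass installs a no-op `NaN` handler (membership only), while array1's pass appends `NaN` to the result buffer at most once; `withNaNCheckCode` emits the equivalent branches as Java source for the codegen path.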
