This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f71f377  [SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.Nan
f71f377 is described below

commit f71f37755d581017f549ecc8683fb7afc2852c67
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Tue Sep 14 18:25:47 2021 +0800

    [SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.Nan
    
    ### What changes were proposed in this pull request?
    For query
    ```
    select array_union(array(cast('nan' as double), cast('nan' as double)), array())
    ```
    This returns [NaN, NaN], but it should return [NaN].
    The issue occurs because `OpenHashSet` cannot deduplicate `Double.NaN` and
    `Float.NaN` values either.
    This PR adds `SQLOpenHashSet`, a wrapper around `OpenHashSet` that handles `null`,
    `Double.NaN`, and `Float.NaN` together.
    
    ### Why are the changes needed?
    To fix a correctness bug: `array_union` could return duplicated `NaN` elements.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `ArrayUnion` (`array_union`) no longer returns duplicated `NaN` values.
    
    ### How was this patch tested?
    Added unit tests to `CollectionExpressionsSuite`.
    
    Closes #33955 from AngersZhuuuu/SPARK-36702-WrapOpenHashSet.
    
    Lead-authored-by: Angerszhuuuu <angers....@gmail.com>
    Co-authored-by: AngersZhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 61 +++++++++++++-----
 .../org/apache/spark/sql/util/SQLOpenHashSet.scala | 72 ++++++++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   | 17 +++++
 3 files changed, 133 insertions(+), 17 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ce17231..e5620a1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
@@ -3575,24 +3576,31 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
     if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
-        val hs = new OpenHashSet[Any]
-        var foundNullElement = false
+        val hs = new SQLOpenHashSet[Any]()
+        val isNaN = SQLOpenHashSet.isNaN(elementType)
         Seq(array1, array2).foreach { array =>
           var i = 0
           while (i < array.numElements()) {
             if (array.isNullAt(i)) {
-              if (!foundNullElement) {
+              if (!hs.containsNull) {
+                hs.addNull
                 arrayBuffer += null
-                foundNullElement = true
               }
             } else {
               val elem = array.get(i, elementType)
-              if (!hs.contains(elem)) {
-                if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-                  
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+              if (isNaN(elem)) {
+                if (!hs.containsNaN) {
+                  arrayBuffer += elem
+                  hs.addNaN
+                }
+              } else {
+                if (!hs.contains(elem)) {
+                  if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+                    
ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
+                  }
+                  arrayBuffer += elem
+                  hs.add(elem)
                 }
-                arrayBuffer += elem
-                hs.add(elem)
               }
             }
             i += 1
@@ -3649,13 +3657,12 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
       val ptName = CodeGenerator.primitiveTypeName(jt)
 
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
-        val foundNullElement = ctx.freshName("foundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
         val builder = ctx.freshName("builder")
         val array = ctx.freshName("array")
         val arrays = ctx.freshName("arrays")
         val arrayDataIdx = ctx.freshName("arrayDataIdx")
-        val openHashSet = classOf[OpenHashSet[_]].getName
+        val openHashSet = classOf[SQLOpenHashSet[_]].getName
         val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
         val hashSet = ctx.freshName("hashSet")
         val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
@@ -3665,9 +3672,9 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
           if (dataType.asInstanceOf[ArrayType].containsNull) {
             s"""
                |if ($array.isNullAt($i)) {
-               |  if (!$foundNullElement) {
+               |  if (!$hashSet.containsNull()) {
                |    $nullElementIndex = $size;
-               |    $foundNullElement = true;
+               |    $hashSet.addNull();
                |    $size++;
                |    $builder.$$plus$$eq($nullValueHolder);
                |  }
@@ -3679,9 +3686,28 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
             body
           }
 
-        val processArray = withArrayNullAssignment(
+        def withNaNCheck(body: String): String = {
+          (elementType match {
+            case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)")
+            case FloatType => Some(s"java.lang.Float.isNaN((float)$value)")
+            case _ => None
+          }).map { isNaN =>
+            s"""
+               |if ($isNaN) {
+               |  if (!$hashSet.containsNaN()) {
+               |     $size++;
+               |     $hashSet.addNaN();
+               |     $builder.$$plus$$eq($value);
+               |  }
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          }
+        }.getOrElse(body)
+
+        val body =
           s"""
-             |$jt $value = ${genGetValue(array, i)};
              |if (!$hashSet.contains($hsValueCast$value)) {
              |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
              |    break;
@@ -3689,12 +3715,13 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArrayBinaryLi
              |  $hashSet.add$hsPostFix($hsValueCast$value);
              |  $builder.$$plus$$eq($value);
              |}
-           """.stripMargin)
+           """.stripMargin
+        val processArray =
+          withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + 
withNaNCheck(body))
 
         // Only need to track null element index when result array's element 
is nullable.
         val declareNullTrackVariables = if 
(dataType.asInstanceOf[ArrayType].containsNull) {
           s"""
-             |boolean $foundNullElement = false;
              |int $nullElementIndex = -1;
            """.stripMargin
         } else {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
new file mode 100644
index 0000000..5ffe733
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import scala.reflect._
+
+import org.apache.spark.annotation.Private
+import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}
+import org.apache.spark.util.collection.OpenHashSet
+
+// A wrap of OpenHashSet that can handle null, Double.NaN and Float.NaN w.r.t. 
the SQL semantic.
+@Private
+class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
+    initialCapacity: Int,
+    loadFactor: Double) {
+
+  def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+  def this() = this(64)
+
+  private val hashSet = new OpenHashSet[T](initialCapacity, loadFactor)
+
+  private var containNull = false
+  private var containNaN = false
+
+  def addNull(): Unit = {
+    containNull = true
+  }
+
+  def addNaN(): Unit = {
+    containNaN = true
+  }
+
+  def add(k: T): Unit = {
+    hashSet.add(k)
+  }
+
+  def contains(k: T): Boolean = {
+    hashSet.contains(k)
+  }
+
+  def containsNull(): Boolean = containNull
+
+  def containsNaN(): Boolean = containNaN
+}
+
+object SQLOpenHashSet {
+  def isNaN(dataType: DataType): Any => Boolean = {
+    dataType match {
+      case DoubleType =>
+        (value: Any) => 
java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
+      case FloatType =>
+        (value: Any) => 
java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
+      case _ => (_: Any) => false
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 688ee61..f4221a8 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2309,4 +2309,21 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       }
     }
   }
+
+  test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and 
Float.Nan") {
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))),
+      Seq(Double.NaN, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
+      Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))),
+      Seq(Double.NaN, null, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))),
+      Seq(Float.NaN, 1f))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
+      Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
+      Seq(Float.NaN, null, 1f))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to