This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new 1042481 [SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.NaN 1042481 is described below commit 104248120094e3f2356023744a8efc0df22b0da6 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Tue Sep 14 18:25:47 2021 +0800 [SPARK-36702][SQL] ArrayUnion handle duplicated Double.NaN and Float.NaN ### What changes were proposed in this pull request? For the query ``` select array_union(array(cast('nan' as double), cast('nan' as double)), array()) ``` This returns [NaN, NaN], but it should return [NaN]. This issue is caused by the fact that `OpenHashSet` cannot handle `Double.NaN` and `Float.NaN` either. In this PR we add a wrapper around `OpenHashSet` that can handle `null`, `Double.NaN`, and `Float.NaN` together ### Why are the changes needed? Fixes a bug ### Does this PR introduce _any_ user-facing change? ArrayUnion won't return duplicated `NaN` values ### How was this patch tested? Added unit tests Closes #33955 from AngersZhuuuu/SPARK-36702-WrapOpenHashSet. 
Lead-authored-by: Angerszhuuuu <angers....@gmail.com> Co-authored-by: AngersZhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit f71f37755d581017f549ecc8683fb7afc2852c67) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/collectionOperations.scala | 61 +++++++++++++----- .../org/apache/spark/sql/util/SQLOpenHashSet.scala | 72 ++++++++++++++++++++++ .../expressions/CollectionExpressionsSuite.scala | 17 +++++ 3 files changed, 133 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bb2163c..b829ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -3367,24 +3368,31 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new OpenHashSet[Any] - var foundNullElement = false + val hs = new SQLOpenHashSet[Any]() + val isNaN = SQLOpenHashSet.isNaN(elementType) Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { if (array.isNullAt(i)) { - if (!foundNullElement) { + if (!hs.containsNull) { + hs.addNull 
arrayBuffer += null - foundNullElement = true } } else { val elem = array.get(i, elementType) - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + if (isNaN(elem)) { + if (!hs.containsNaN) { + arrayBuffer += elem + hs.addNaN + } + } else { + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) } - arrayBuffer += elem - hs.add(elem) } } i += 1 @@ -3441,13 +3449,12 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array1, array2) => { - val foundNullElement = ctx.freshName("foundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") val array = ctx.freshName("array") val arrays = ctx.freshName("arrays") val arrayDataIdx = ctx.freshName("arrayDataIdx") - val openHashSet = classOf[OpenHashSet[_]].getName + val openHashSet = classOf[SQLOpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName @@ -3457,9 +3464,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi if (dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array.isNullAt($i)) { - | if (!$foundNullElement) { + | if (!$hashSet.containsNull()) { | $nullElementIndex = $size; - | $foundNullElement = true; + | $hashSet.addNull(); | $size++; | $builder.$$plus$$eq($nullValueHolder); | } @@ -3471,9 +3478,28 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi body } - val processArray = withArrayNullAssignment( + def withNaNCheck(body: String): String = { + (elementType match { + case 
DoubleType => Some(s"java.lang.Double.isNaN((double)$value)") + case FloatType => Some(s"java.lang.Float.isNaN((float)$value)") + case _ => None + }).map { isNaN => + s""" + |if ($isNaN) { + | if (!$hashSet.containsNaN()) { + | $size++; + | $hashSet.addNaN(); + | $builder.$$plus$$eq($value); + | } + |} else { + | $body + |} + """.stripMargin + } + }.getOrElse(body) + + val body = s""" - |$jt $value = ${genGetValue(array, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -3481,12 +3507,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi | $hashSet.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin) + """.stripMargin + val processArray = + withArrayNullAssignment(s"$jt $value = ${genGetValue(array, i)};" + withNaNCheck(body)) // Only need to track null element index when result array's element is nullable. val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |boolean $foundNullElement = false; |int $nullElementIndex = -1; """.stripMargin } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala new file mode 100644 index 0000000..5ffe733 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import scala.reflect._ + +import org.apache.spark.annotation.Private +import org.apache.spark.sql.types.{DataType, DoubleType, FloatType} +import org.apache.spark.util.collection.OpenHashSet + +// A wrap of OpenHashSet that can handle null, Double.NaN and Float.NaN w.r.t. the SQL semantic. +@Private +class SQLOpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( + initialCapacity: Int, + loadFactor: Double) { + + def this(initialCapacity: Int) = this(initialCapacity, 0.7) + + def this() = this(64) + + private val hashSet = new OpenHashSet[T](initialCapacity, loadFactor) + + private var containNull = false + private var containNaN = false + + def addNull(): Unit = { + containNull = true + } + + def addNaN(): Unit = { + containNaN = true + } + + def add(k: T): Unit = { + hashSet.add(k) + } + + def contains(k: T): Boolean = { + hashSet.contains(k) + } + + def containsNull(): Boolean = containNull + + def containsNaN(): Boolean = containNaN +} + +object SQLOpenHashSet { + def isNaN(dataType: DataType): Any => Boolean = { + dataType match { + case DoubleType => + (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]) + case FloatType => + (value: Any) => java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]) + case _ => (_: Any) => false + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 
d79f06f..25e40c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1948,4 +1948,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } } } + + test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.Nan") { + checkEvaluation(ArrayUnion( + Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))), + Seq(Double.NaN, 1d)) + checkEvaluation(ArrayUnion( + Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)), + Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))), + Seq(Double.NaN, null, 1d)) + checkEvaluation(ArrayUnion( + Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))), + Seq(Float.NaN, 1f)) + checkEvaluation(ArrayUnion( + Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)), + Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))), + Seq(Float.NaN, null, 1f)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org