This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new b96cbfb [SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float.NaN b96cbfb is described below commit b96cbfb27d33f28dc5526dcf626642244125957d Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Wed Sep 22 23:51:41 2021 +0800 [SPARK-36753][SQL] ArrayExcept handle duplicated Double.NaN and Float.NaN ### What changes were proposed in this pull request? For query ``` select array_except(array(cast('nan' as double), 1d), array(cast('nan' as double))) ``` This returns [NaN, 1d], but it should return [1d]. This issue is caused by `OpenHashSet` not being able to handle `Double.NaN` and `Float.NaN` either. In this PR, fix this based on https://github.com/apache/spark/pull/33955 ### Why are the changes needed? Fix bug ### Does this PR introduce _any_ user-facing change? Yes, ArrayExcept now treats equal `NaN` values as duplicates, so a `NaN` present in the second array removes `NaN` from the result ### How was this patch tested? Added UT Closes #33994 from AngersZhuuuu/SPARK-36753. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit a7cbe699863a6b68d27bdf3934dda7d396d80404) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/collectionOperations.scala | 61 +++++++++++++--------- .../expressions/CollectionExpressionsSuite.scala | 17 ++++++ 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c153181..d89c77e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -37,7 +37,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import 
org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String} -import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -3839,32 +3838,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => - val hs = new OpenHashSet[Any] - var notFoundNullElement = true + val hs = new SQLOpenHashSet[Any] + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => hs.add(value), + (valueNaN: Any) => {}) + val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => + if (!hs.contains(value)) { + arrayBuffer += value + hs.add(value) + }, + (valueNaN: Any) => arrayBuffer += valueNaN) var i = 0 while (i < array2.numElements()) { if (array2.isNullAt(i)) { - notFoundNullElement = false + hs.addNull() } else { val elem = array2.get(i, elementType) - hs.add(elem) + withArray2NaNCheckFunc(elem) } i += 1 } - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] i = 0 while (i < array1.numElements()) { if (array1.isNullAt(i)) { - if (notFoundNullElement) { + if (!hs.containsNull()) { arrayBuffer += null - notFoundNullElement = false + hs.addNull() } } else { val elem = array1.get(i, elementType) - if (!hs.contains(elem)) { - arrayBuffer += elem - hs.add(elem) - } + withArray1NaNCheckFunc(elem) } i += 1 } @@ -3933,10 +3938,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array1, array2) => { - val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = 
ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val openHashSet = classOf[OpenHashSet[_]].getName + val openHashSet = classOf[SQLOpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName @@ -3947,7 +3951,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; + | $hashSet.addNull(); |} else { | $body |} @@ -3965,18 +3969,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL } val writeArray2ToHashSet = withArray2NullCheck( - s""" - |$jt $value = ${genGetValue(array2, i)}; - |$hashSet.add$hsPostFix($hsValueCast$value); - """.stripMargin) + s"$jt $value = ${genGetValue(array2, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, + s"$hashSet.add$hsPostFix($hsValueCast$value);", + (valueNaN: Any) => "")) def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" |if ($array1.isNullAt($i)) { - | if ($notFoundNullElement) { + | if (!$hashSet.containsNull()) { + | $hashSet.addNull(); | $nullElementIndex = $size; - | $notFoundNullElement = false; | $size++; | $builder.$$plus$$eq($nullValueHolder); | } @@ -3988,9 +3992,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL body } - val processArray1 = withArray1NullAssignment( + val body = s""" - |$jt $value = ${genGetValue(array1, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -3998,12 +4001,20 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL | $hashSet.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin) + """.stripMargin + + val 
processArray1 = withArray1NullAssignment( + s"$jt $value = ${genGetValue(array1, i)};" + + SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body, + (valueNaN: String) => + s""" + |$size++; + |$builder.$$plus$$eq($valueNaN); + """.stripMargin)) // Only need to track null element index when array1's element is nullable. val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" - |boolean $notFoundNullElement = true; |int $nullElementIndex = -1; """.stripMargin } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6a66e36..efffe95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1912,6 +1912,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Float.NaN, null, 1f)) } + test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") { + checkEvaluation(ArrayExcept( + Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN))), + Seq(1d)) + checkEvaluation(ArrayExcept( + Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)), + Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType))), + Seq(1d)) + checkEvaluation(ArrayExcept( + Literal.apply(Array(Float.NaN, 1f)), Literal.apply(Array(Float.NaN))), + Seq(1f)) + checkEvaluation(ArrayExcept( + Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)), + Literal.create(Seq(Float.NaN, null), ArrayType(FloatType))), + Seq(1f)) + } + test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") { checkEvaluation(ArrayIntersect( Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 
2d))), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org