This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6380859 [SPARK-36702][SQL][FOLLOWUP] ArrayUnion handle duplicated Double.NaN and Float.NaN 6380859 is described below commit 638085953f931f98241856c9f652e5f15202fcc0 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Wed Sep 15 22:04:09 2021 +0800 [SPARK-36702][SQL][FOLLOWUP] ArrayUnion handle duplicated Double.NaN and Float.NaN ### What changes were proposed in this pull request? According to https://github.com/apache/spark/pull/33955#discussion_r708570515 use normalized NaN ### Why are the changes needed? Use normalized NaN for duplicated NaN value ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UT Closes #34003 from AngersZhuuuu/SPARK-36702-FOLLOWUP. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/expressions/collectionOperations.scala | 13 ++++++++----- .../scala/org/apache/spark/sql/util/SQLOpenHashSet.scala | 8 ++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e5620a1..47b2719 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3578,6 +3578,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new SQLOpenHashSet[Any]() val isNaN = SQLOpenHashSet.isNaN(elementType) + val valueNaN = SQLOpenHashSet.valueNaN(elementType) Seq(array1, array2).foreach { array => var i = 0 while (i < array.numElements()) { @@ -3590,7 +3591,7 @@ case class ArrayUnion(left: Expression, right: Expression)
extends ArrayBinaryLi val elem = array.get(i, elementType) if (isNaN(elem)) { if (!hs.containsNaN) { - arrayBuffer += elem + arrayBuffer += valueNaN hs.addNaN } } else { @@ -3688,16 +3689,18 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi def withNaNCheck(body: String): String = { (elementType match { - case DoubleType => Some(s"java.lang.Double.isNaN((double)$value)") - case FloatType => Some(s"java.lang.Float.isNaN((float)$value)") + case DoubleType => + Some((s"java.lang.Double.isNaN((double)$value)", "java.lang.Double.NaN")) + case FloatType => + Some((s"java.lang.Float.isNaN((float)$value)", "java.lang.Float.NaN")) case _ => None - }).map { isNaN => + }).map { case (isNaN, valueNaN) => s""" |if ($isNaN) { | if (!$hashSet.containsNaN()) { | $size++; | $hashSet.addNaN(); - | $builder.$$plus$$eq($value); + | $builder.$$plus$$eq($valueNaN); | } |} else { | $body diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala index 5ffe733..083cfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SQLOpenHashSet.scala @@ -69,4 +69,12 @@ object SQLOpenHashSet { case _ => (_: Any) => false } } + + def valueNaN(dataType: DataType): Any = { + dataType match { + case DoubleType => java.lang.Double.NaN + case FloatType => java.lang.Float.NaN + case _ => null + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org