This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new e1fc62d [SPARK-36792][SQL] InSet should handle NaN e1fc62d is described below commit e1fc62de8e05f2606972434419898c5810339995 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Fri Sep 24 16:19:36 2021 +0800 [SPARK-36792][SQL] InSet should handle NaN ### What changes were proposed in this pull request? InSet should handle NaN ``` InSet(Literal(Double.NaN), Set(Double.NaN, 1d)) should return true, but returns false. ``` ### Why are the changes needed? InSet should handle NaN ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UT Closes #34033 from AngersZhuuuu/SPARK-36792. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 64f4bf47af2412811ff2843cd363ce883a604ce7) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/expressions/predicates.scala | 38 +++++++++++++++++++--- .../sql/catalyst/expressions/PredicateSuite.scala | 14 ++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 53ac356..f2d91b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -487,12 +487,24 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + @transient private[this] lazy val isNaN: Any => Boolean = child.dataType match { + case DoubleType => (value: Any) => java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double]) + case FloatType => (value: Any) => 
java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float]) + case _ => (_: Any) => false + } + @transient private[this] lazy val hasNaN = child.dataType match { + case DoubleType | FloatType => set.exists(isNaN) + case _ => false + } + override def nullable: Boolean = child.nullable || hasNull protected override def nullSafeEval(value: Any): Any = { if (set.contains(value)) { true + } else if (isNaN(value)) { + hasNaN } else if (hasNull) { null } else { @@ -524,15 +536,33 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { val setTerm = ctx.addReferenceObj("set", set) + val setIsNull = if (hasNull) { s"${ev.isNull} = !${ev.value};" } else { "" } - s""" - |${ev.value} = $setTerm.contains($c); - |$setIsNull - """.stripMargin + + val ret = child.dataType match { + case DoubleType => Some((v: Any) => s"java.lang.Double.isNaN($v)") + case FloatType => Some((v: Any) => s"java.lang.Float.isNaN($v)") + case _ => None + } + + ret.map { isNaN => + s""" + |if ($setTerm.contains($c)) { + | ${ev.value} = true; + |} else if (${isNaN(c)}) { + | ${ev.value} = $hasNaN; + |} + |$setIsNull + |""".stripMargin + }.getOrElse( + s""" + |${ev.value} = $setTerm.contains($c); + |$setIsNull + """.stripMargin) }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6f75623..c34b37d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -644,4 +644,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkExpr(GreaterThan, Double.NaN, Double.NaN, false) checkExpr(GreaterThan, 0.0, -0.0, false) } + + test("SPARK-36792: InSet should handle 
Double.NaN and Float.NaN") { + checkInAndInSet(In(Literal(Double.NaN), Seq(Literal(Double.NaN), Literal(2d))), true) + checkInAndInSet(In(Literal.create(null, DoubleType), + Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), null) + checkInAndInSet(In(Literal.create(null, DoubleType), + Seq(Literal(Double.NaN), Literal(2d))), null) + checkInAndInSet(In(Literal(3d), + Seq(Literal(Double.NaN), Literal(2d))), false) + checkInAndInSet(In(Literal(3d), + Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), null) + checkInAndInSet(In(Literal(Double.NaN), + Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, DoubleType))), true) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org