This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new e1fc62d  [SPARK-36792][SQL] InSet should handle NaN
e1fc62d is described below

commit e1fc62de8e05f2606972434419898c5810339995
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Fri Sep 24 16:19:36 2021 +0800

    [SPARK-36792][SQL] InSet should handle NaN
    
    ### What changes were proposed in this pull request?
    InSet should handle NaN
    ```
    InSet(Literal(Double.NaN), Set(Double.NaN, 1d)) should return true, but
    currently returns false.
    ```
    ### Why are the changes needed?
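    InSet should handle NaN. Before this change, a NaN input was not matched by
    the plain set lookup, so InSet returned false even when the set contained NaN.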
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test in PredicateSuite.
    
    Closes #34033 from AngersZhuuuu/SPARK-36792.
    
    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 64f4bf47af2412811ff2843cd363ce883a604ce7)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/predicates.scala      | 38 +++++++++++++++++++---
 .../sql/catalyst/expressions/PredicateSuite.scala  | 14 ++++++++
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 53ac356..f2d91b5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -487,12 +487,24 @@ case class InSet(child: Expression, hset: Set[Any]) 
extends UnaryExpression with
   override def toString: String = s"$child INSET ${hset.mkString("(", ",", 
")")}"
 
   @transient private[this] lazy val hasNull: Boolean = hset.contains(null)
+  @transient private[this] lazy val isNaN: Any => Boolean = child.dataType 
match {
+    case DoubleType => (value: Any) => 
java.lang.Double.isNaN(value.asInstanceOf[java.lang.Double])
+    case FloatType => (value: Any) => 
java.lang.Float.isNaN(value.asInstanceOf[java.lang.Float])
+    case _ => (_: Any) => false
+  }
+  @transient private[this] lazy val hasNaN = child.dataType match {
+    case DoubleType | FloatType => set.exists(isNaN)
+    case _ => false
+  }
+
 
   override def nullable: Boolean = child.nullable || hasNull
 
   protected override def nullSafeEval(value: Any): Any = {
     if (set.contains(value)) {
       true
+    } else if (isNaN(value)) {
+      hasNaN
     } else if (hasNull) {
       null
     } else {
@@ -524,15 +536,33 @@ case class InSet(child: Expression, hset: Set[Any]) 
extends UnaryExpression with
   private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, c => {
       val setTerm = ctx.addReferenceObj("set", set)
+
       val setIsNull = if (hasNull) {
         s"${ev.isNull} = !${ev.value};"
       } else {
         ""
       }
-      s"""
-         |${ev.value} = $setTerm.contains($c);
-         |$setIsNull
-       """.stripMargin
+
+      val ret = child.dataType match {
+        case DoubleType => Some((v: Any) => s"java.lang.Double.isNaN($v)")
+        case FloatType => Some((v: Any) => s"java.lang.Float.isNaN($v)")
+        case _ => None
+      }
+
+      ret.map { isNaN =>
+        s"""
+          |if ($setTerm.contains($c)) {
+          |  ${ev.value} = true;
+          |} else if (${isNaN(c)}) {
+          |  ${ev.value} =  $hasNaN;
+          |}
+          |$setIsNull
+          |""".stripMargin
+      }.getOrElse(
+        s"""
+           |${ev.value} = $setTerm.contains($c);
+           |$setIsNull
+         """.stripMargin)
     })
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 6f75623..c34b37d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -644,4 +644,18 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
     checkExpr(GreaterThan, 0.0, -0.0, false)
   }
+
+  test("SPARK-36792: InSet should handle Double.NaN and Float.NaN") {
+    checkInAndInSet(In(Literal(Double.NaN), Seq(Literal(Double.NaN), 
Literal(2d))), true)
+    checkInAndInSet(In(Literal.create(null, DoubleType),
+      Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, 
DoubleType))), null)
+    checkInAndInSet(In(Literal.create(null, DoubleType),
+      Seq(Literal(Double.NaN), Literal(2d))), null)
+    checkInAndInSet(In(Literal(3d),
+      Seq(Literal(Double.NaN), Literal(2d))), false)
+    checkInAndInSet(In(Literal(3d),
+      Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, 
DoubleType))), null)
+    checkInAndInSet(In(Literal(Double.NaN),
+      Seq(Literal(Double.NaN), Literal(2d), Literal.create(null, 
DoubleType))), true)
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to