Repository: spark
Updated Branches:
  refs/heads/master 7a83d7140 -> b9b68a6dc


[SPARK-26211][SQL] Fix InSet for binary, and struct and array with null.

## What changes were proposed in this pull request?

Currently `InSet` doesn't work properly for binary type, or struct and array 
type with null value in the set.
Because, as for binary type, the `HashSet` doesn't work properly for 
`Array[Byte]`, and as for struct and array type with null value in the set, the 
`ordering` will throw a `NPE`.

## How was this patch tested?

Added a few tests.

Closes #23176 from ueshin/issues/SPARK-26211/inset.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9b68a6d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9b68a6d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9b68a6d

Branch: refs/heads/master
Commit: b9b68a6dc7d0f735163e980392ea957f2d589923
Parents: 7a83d71
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Thu Nov 29 22:37:02 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu Nov 29 22:37:02 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/predicates.scala   | 33 ++++++-------
 .../catalyst/expressions/PredicateSuite.scala   | 50 +++++++++++++++++++-
 2 files changed, 63 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b9b68a6d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 16e0bc3..01ecb99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -367,31 +367,26 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   }
 
   @transient lazy val set: Set[Any] = child.dataType match {
-    case _: AtomicType => hset
+    case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
     case _: NullType => hset
     case _ =>
      // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
-      TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
+      TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val setTerm = ctx.addReferenceObj("set", set)
-    val childGen = child.genCode(ctx)
-    val setIsNull = if (hasNull) {
-      s"${ev.isNull} = !${ev.value};"
-    } else {
-      ""
-    }
-    ev.copy(code =
-      code"""
-         |${childGen.code}
-         |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
-         |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
-         |if (!${ev.isNull}) {
-         |  ${ev.value} = $setTerm.contains(${childGen.value});
-         |  $setIsNull
-         |}
-       """.stripMargin)
+    nullSafeCodeGen(ctx, ev, c => {
+      val setTerm = ctx.addReferenceObj("set", set)
+      val setIsNull = if (hasNull) {
+        s"${ev.isNull} = !${ev.value};"
+      } else {
+        ""
+      }
+      s"""
+         |${ev.value} = $setTerm.contains($c);
+         |$setIsNull
+       """.stripMargin
+    })
   }
 
   override def sql: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/b9b68a6d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index ac76b17..3b60d1d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(InSet(nl, nS), null)
 
     val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
-      LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+      LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
     primitiveTypes.foreach { t =>
       val dataGen = RandomDataGenerator.forType(t, nullable = true).get
       val inputData = Seq.fill(10) {
@@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("INSET: binary") {
+    val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
+    val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
+    val onetwo = Literal(Array(1.toByte, 2.toByte))
+    val three = Literal(Array(3.toByte))
+    val threefour = Literal(Array(3.toByte, 4.toByte))
+    val nl = Literal(null, onetwo.dataType)
+    checkEvaluation(InSet(onetwo, hS), true)
+    checkEvaluation(InSet(three, hS), true)
+    checkEvaluation(InSet(three, nS), true)
+    checkEvaluation(InSet(threefour, hS), false)
+    checkEvaluation(InSet(threefour, nS), null)
+    checkEvaluation(InSet(nl, hS), null)
+    checkEvaluation(InSet(nl, nS), null)
+  }
+
+  test("INSET: struct") {
+    val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
+    val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
+    val oneA = Literal.create((1, "a"))
+    val twoB = Literal.create((2, "b"))
+    val twoC = Literal.create((2, "c"))
+    val nl = Literal(null, oneA.dataType)
+    checkEvaluation(InSet(oneA, hS), true)
+    checkEvaluation(InSet(twoB, hS), true)
+    checkEvaluation(InSet(twoB, nS), true)
+    checkEvaluation(InSet(twoC, hS), false)
+    checkEvaluation(InSet(twoC, nS), null)
+    checkEvaluation(InSet(nl, hS), null)
+    checkEvaluation(InSet(nl, nS), null)
+  }
+
+  test("INSET: array") {
+    val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
+    val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
+    val onetwo = Literal.create(Seq(1, 2))
+    val three = Literal.create(Seq(3))
+    val threefour = Literal.create(Seq(3, 4))
+    val nl = Literal(null, onetwo.dataType)
+    checkEvaluation(InSet(onetwo, hS), true)
+    checkEvaluation(InSet(three, hS), true)
+    checkEvaluation(InSet(three, nS), true)
+    checkEvaluation(InSet(threefour, hS), false)
+    checkEvaluation(InSet(threefour, nS), null)
+    checkEvaluation(InSet(nl, hS), null)
+    checkEvaluation(InSet(nl, nS), null)
+  }
+
   private case class MyStruct(a: Long, b: String)
   private case class MyStruct2(a: MyStruct, b: Array[Int])
   private val udt = new ExamplePointUDT


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to