Repository: spark
Updated Branches:
  refs/heads/master 7ce110828 -> cba69aeb4


[SPARK-21110][SQL] Structs, arrays, and other orderable datatypes should be usable in inequalities

## What changes were proposed in this pull request?

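Allows `BinaryComparison` operators to work on any data type that actually
supports ordering, as verified by `TypeUtils.checkForOrderingExpr`, instead of
relying on the incomplete list `TypeCollection.Ordered` (which this PR removes).
Map types remain unsupported in all comparisons, including `EqualTo` and
`EqualNullSafe`; that error is now reported by the shared ordering check.

To support this, `TypeUtils.getInterpretedOrdering` now handles user-defined
types by delegating to their underlying `sqlType`, and `CodegenContext.genEqual`
now generates equality code for `NullType` operands.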

## How was this patch tested?

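Updated unit tests to cover structs, arrays, user-defined types, additional
atomic types (byte, short, long, date, and timestamp), and `NullType` operands,
and updated the expected error messages for comparisons on map types.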

Author: Andrew Ray <ray.and...@gmail.com>

Closes #18818 from aray/SPARK-21110.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cba69aeb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cba69aeb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cba69aeb

Branch: refs/heads/master
Commit: cba69aeb453d2489830f3e6e0473a64dee81989e
Parents: 7ce1108
Author: Andrew Ray <ray.and...@gmail.com>
Authored: Thu Aug 31 15:08:03 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Aug 31 15:08:03 2017 -0700

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  1 +
 .../sql/catalyst/expressions/predicates.scala   | 58 +++++---------------
 .../spark/sql/catalyst/util/TypeUtils.scala     |  1 +
 .../spark/sql/types/AbstractDataType.scala      | 12 ----
 .../catalyst/analysis/AnalysisErrorSuite.scala  |  2 +-
 .../analysis/ExpressionTypeCheckingSuite.scala  | 15 ++---
 .../catalyst/expressions/PredicateSuite.scala   | 37 +++++++++++--
 7 files changed, 58 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3853863..4373971 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -594,6 +594,7 @@ class CodegenContext {
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
+    case NullType => "false"
     case _ =>
       throw new IllegalArgumentException(
         "cannot generate equality code for un-comparable type: " + 
dataType.simpleString)

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 613d620..d3071c5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -448,6 +448,16 @@ case class Or(left: Expression, right: Expression) extends 
BinaryOperator with P
 
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
+  // Note that we need to give a superset of allowable input types since 
orderable types are not
+  // finitely enumerable. The allowable types are checked below by 
checkInputDataTypes.
+  override def inputType: AbstractDataType = AnyDataType
+
+  override def checkInputDataTypes(): TypeCheckResult = 
super.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess =>
+      TypeUtils.checkForOrderingExpr(left.dataType, 
this.getClass.getSimpleName)
+    case failure => failure
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     if (ctx.isPrimitiveType(left.dataType)
         && left.dataType != BooleanType // java boolean doesn't support > or < 
operator
@@ -460,7 +470,7 @@ abstract class BinaryComparison extends BinaryOperator with 
Predicate {
     }
   }
 
-  protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
+  protected lazy val ordering: Ordering[Any] = 
TypeUtils.getInterpretedOrdering(left.dataType)
 }
 
 
@@ -478,28 +488,13 @@ object Equality {
   }
 }
 
+// TODO: although map type is not orderable, technically map type should be 
able to be used
+// in equality comparison
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or 
false otherwise.")
 case class EqualTo(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type 
should be able to be used
-        // in equality comparison, remove this type check once we support it.
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, 
but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "="
 
   protected override def nullSafeEval(left: Any, right: Any): Any = 
ordering.equiv(left, right)
@@ -509,6 +504,8 @@ case class EqualTo(left: Expression, right: Expression)
   }
 }
 
+// TODO: although map type is not orderable, technically map type should be 
able to be used
+// in equality comparison
 @ExpressionDescription(
   usage = """
     expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for 
non-null operands,
@@ -516,23 +513,6 @@ case class EqualTo(left: Expression, right: Expression)
   """)
 case class EqualNullSafe(left: Expression, right: Expression) extends 
BinaryComparison {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type 
should be able to be used
-        // in equality comparison, remove this type check once we support it.
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in 
EqualNullSafe, but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "<=>"
 
   override def nullable: Boolean = false
@@ -564,8 +544,6 @@ case class EqualNullSafe(left: Expression, right: 
Expression) extends BinaryComp
 case class LessThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
ordering.lt(input1, input2)
@@ -576,8 +554,6 @@ case class LessThan(left: Expression, right: Expression)
 case class LessThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
ordering.lteq(input1, input2)
@@ -588,8 +564,6 @@ case class LessThanOrEqual(left: Expression, right: 
Expression)
 case class GreaterThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
ordering.gt(input1, input2)
@@ -600,8 +574,6 @@ case class GreaterThan(left: Expression, right: Expression)
 case class GreaterThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = 
ordering.gteq(input1, input2)

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 4522577..1dcda49 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -65,6 +65,7 @@ object TypeUtils {
       case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
       case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d54ff5..3041f44 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -79,18 +79,6 @@ private[sql] class TypeCollection(private val types: 
Seq[AbstractDataType])
 private[sql] object TypeCollection {
 
   /**
-   * Types that can be ordered/compared. In the long run we should probably 
make this a trait
-   * that can be mixed into each data type, and perhaps create an 
`AbstractDataType`.
-   */
-  // TODO: Should we consolidate this with RowOrdering.isOrderable?
-  val Ordered = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType,
-    FloatType, DoubleType, DecimalType,
-    TimestampType, DateType,
-    StringType, BinaryType)
-
-  /**
    * Types that include numeric types and interval type. They are only used in 
unary_minus,
    * unary_positive, add and subtract operations.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e06136..884e113 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       right,
       joinType = Cross,
       condition = Some('b === 'd))
-    assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
+    assertAnalysisError(plan2, "EqualTo does not support ordering on type 
MapType" :: Nil)
   }
 
   test("PredicateSubQuery is used outside of a filter") {

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 3072577..36714bd 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
+import org.apache.spark.sql.types._
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(EqualTo('mapField, 'mapField), "Cannot use map type in 
EqualTo")
-    assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in 
EqualNullSafe")
+    assertError(EqualTo('mapField, 'mapField), "EqualTo does not support 
ordering on type MapType")
+    assertError(EqualNullSafe('mapField, 'mapField),
+      "EqualNullSafe does not support ordering on type MapType")
     assertError(LessThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThan does not support ordering on type MapType")
     assertError(LessThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThanOrEqual does not support ordering on type MapType")
     assertError(GreaterThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThan does not support ordering on type MapType")
     assertError(GreaterThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThanOrEqual does not support ordering on type MapType")
 
     assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")

http://git-wip-us.apache.org/repos/asf/spark/blob/cba69aeb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index ef510a9..055c31c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+
 import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 
@@ -215,14 +218,35 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     }
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, 
false).map(Literal(_))
+  private case class MyStruct(a: Long, b: String)
+  private case class MyStruct2(a: MyStruct, b: Array[Int])
+  private val udt = new ExamplePointUDT
+
+  private val smallValues =
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new 
Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, 
true).map(Literal(_))
+    Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new 
Date(2000, 1, 2),
+      new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 
1L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, 
true).map(Literal(_))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new 
Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 
2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, 
true).map(Literal(_))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new 
Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 
2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
 
   test("BinaryComparison consistency check") {
     DataTypeTestUtils.ordered.foreach { dt =>
@@ -285,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     // Use -1 (default value for codegen) which can trigger some weird bugs, 
e.g. SPARK-14757
     val normalInt = Literal(-1)
     val nullInt = NonFoldableLiteral.create(null, IntegerType)
+    val nullNullType = Literal.create(null, NullType)
 
     def nullTest(op: (Expression, Expression) => Expression): Unit = {
       checkEvaluation(op(normalInt, nullInt), null)
       checkEvaluation(op(nullInt, normalInt), null)
       checkEvaluation(op(nullInt, nullInt), null)
+      checkEvaluation(op(nullNullType, nullNullType), null)
     }
 
     nullTest(LessThan)
@@ -301,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
     checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
     checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
+    checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
   }
 
   test("EqualTo on complex type") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to