Repository: spark
Updated Branches:
  refs/heads/master f16140975 -> 8af61fba0


[SPARK-25122][SQL] Deduplication of supports equals code

## What changes were proposed in this pull request?

The method ```*supportEquals```, which determines whether elements of a data type 
can be used as items in a hash set or as keys in a hash map, is duplicated 
across multiple collection and higher-order functions.

This PR deduplicates the method into ```TypeUtils```.

## How was this patch tested?

Run tests in:
- DataFrameFunctionsSuite
- CollectionExpressionsSuite
- HigherOrderExpressionsSuite

Closes #22110 from mn-mikke/SPARK-25122.

Authored-by: Marek Novotny <mn.mi...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8af61fba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8af61fba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8af61fba

Branch: refs/heads/master
Commit: 8af61fba03e1d32ddee4e83717fc8137682ffae6
Parents: f161409
Author: Marek Novotny <mn.mi...@gmail.com>
Authored: Fri Aug 17 11:52:16 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Aug 17 11:52:16 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 38 +++++++-------------
 .../expressions/higherOrderFunctions.scala      |  8 +----
 .../spark/sql/catalyst/util/TypeUtils.scala     | 13 ++++++-
 3 files changed, 25 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 5e3449d..cf9796e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1505,13 +1505,7 @@ case class ArraysOverlap(left: Expression, right: 
Expression)
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(elementType)
 
-  @transient private lazy val elementTypeSupportEquals = elementType match {
-    case BinaryType => false
-    case _: AtomicType => true
-    case _ => false
-  }
-
-  @transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
+  @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
     fastEval _
   } else {
     bruteForceEval _
@@ -1593,7 +1587,7 @@ case class ArraysOverlap(left: Expression, right: 
Expression)
     nullSafeCodeGen(ctx, ev, (a1, a2) => {
       val smaller = ctx.freshName("smallerArray")
       val bigger = ctx.freshName("biggerArray")
-      val comparisonCode = if (elementTypeSupportEquals) {
+      val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) {
         fastCodegen(ctx, ev, smaller, bigger)
       } else {
         bruteForceCodegen(ctx, ev, smaller, bigger)
@@ -3404,12 +3398,6 @@ case class ArrayDistinct(child: Expression)
     }
   }
 
-  @transient private lazy val elementTypeSupportEquals = elementType match {
-    case BinaryType => false
-    case _: AtomicType => true
-    case _ => false
-  }
-
   @transient protected lazy val canUseSpecializedHashSet = elementType match {
    case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
     case _ => false
@@ -3434,9 +3422,13 @@ case class ArrayDistinct(child: Expression)
 
   override def nullSafeEval(array: Any): Any = {
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
-    if (elementTypeSupportEquals) {
-      new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
-    } else {
+    doEvaluation(data)
+  }
+
+  @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
+    (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
+  } else {
+    (data: Array[AnyRef]) => {
       var foundNullElement = false
       var pos = 0
       for (i <- 0 until data.length) {
@@ -3576,12 +3568,6 @@ abstract class ArraySetLike extends 
BinaryArrayExpressionWithImplicitCast {
   @transient protected lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(elementType)
 
-  @transient protected lazy val elementTypeSupportEquals = elementType match {
-    case BinaryType => false
-    case _: AtomicType => true
-    case _ => false
-  }
-
   @transient protected lazy val canUseSpecializedHashSet = elementType match {
    case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
     case _ => false
@@ -3679,7 +3665,7 @@ case class ArrayUnion(left: Expression, right: 
Expression) extends ArraySetLike
   with ComplexTypeMergingExpression {
 
   @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
-    if (elementTypeSupportEquals) {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new OpenHashSet[Any]
@@ -3896,7 +3882,7 @@ case class ArrayIntersect(left: Expression, right: 
Expression) extends ArraySetL
   }
 
   @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = {
-    if (elementTypeSupportEquals) {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         if (array1.numElements() != 0 && array2.numElements() != 0) {
           val hs = new OpenHashSet[Any]
@@ -4136,7 +4122,7 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArraySetLike
   }
 
   @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
-    if (elementTypeSupportEquals) {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
       (array1, array2) =>
         val hs = new OpenHashSet[Any]
         var notFoundNullElement = true

http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index f667a64..3e0621d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -683,12 +683,6 @@ case class MapZipWith(left: Expression, right: Expression, 
function: Expression)
     value2Var: NamedLambdaVariable),
     _) = function
 
-  private def keyTypeSupportsEquals = keyType match {
-    case BinaryType => false
-    case _: AtomicType => true
-    case _ => false
-  }
-
   /**
   * The function accepts two key arrays and returns a collection of keys with indexes
   * to value arrays. Indexes are represented as an array of two items. This is a small
@@ -696,7 +690,7 @@ case class MapZipWith(left: Expression, right: Expression, 
function: Expression)
    */
   @transient private lazy val getKeysWithValueIndexes:
       (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
-    if (keyTypeSupportsEquals) {
+    if (TypeUtils.typeWithProperEquals(keyType)) {
       getKeysWithIndexesFast
     } else {
       getKeysWithIndexesBruteForce

http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 5214cdc..76218b4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
 
 /**
- * Helper functions to check for valid data types.
+ * Functions to help with checking for valid data types and value comparison of various types.
  */
 object TypeUtils {
   def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
@@ -73,4 +73,15 @@ object TypeUtils {
     }
     x.length - y.length
   }
+
+  /**
+   * Returns true if the equals method of the elements of the data type is implemented properly.
+   * This also means that they can be safely used in collections relying on the equals method,
+   * as sets or maps.
+   */
+  def typeWithProperEquals(dataType: DataType): Boolean = dataType match {
+    case BinaryType => false
+    case _: AtomicType => true
+    case _ => false
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to