Repository: spark Updated Branches: refs/heads/master f16140975 -> 8af61fba0
[SPARK-25122][SQL] Deduplication of supports equals code ## What changes were proposed in this pull request? The method `*supportEquals`, which determines whether elements of a data type can be used as items in a hash set or as keys in a hash map, is duplicated across multiple collection and higher-order functions. This PR proposes deduplicating the method. ## How was this patch tested? Ran the tests in: - DataFrameFunctionsSuite - CollectionExpressionsSuite - HigherOrderExpressionsSuite Closes #22110 from mn-mikke/SPARK-25122. Authored-by: Marek Novotny <mn.mi...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8af61fba Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8af61fba Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8af61fba Branch: refs/heads/master Commit: 8af61fba03e1d32ddee4e83717fc8137682ffae6 Parents: f161409 Author: Marek Novotny <mn.mi...@gmail.com> Authored: Fri Aug 17 11:52:16 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Aug 17 11:52:16 2018 +0800 ---------------------------------------------------------------------- .../expressions/collectionOperations.scala | 38 +++++++------------- .../expressions/higherOrderFunctions.scala | 8 +---- .../spark/sql/catalyst/util/TypeUtils.scala | 13 ++++++- 3 files changed, 25 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5e3449d..cf9796e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1505,13 +1505,7 @@ case class ArraysOverlap(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - - @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { fastEval _ } else { bruteForceEval _ @@ -1593,7 +1587,7 @@ case class ArraysOverlap(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (a1, a2) => { val smaller = ctx.freshName("smallerArray") val bigger = ctx.freshName("biggerArray") - val comparisonCode = if (elementTypeSupportEquals) { + val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) { fastCodegen(ctx, ev, smaller, bigger) } else { bruteForceCodegen(ctx, ev, smaller, bigger) @@ -3404,12 +3398,6 @@ case class ArrayDistinct(child: Expression) } } - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - @transient protected lazy val canUseSpecializedHashSet = elementType match { case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false @@ -3434,9 +3422,13 @@ case class ArrayDistinct(child: Expression) override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementTypeSupportEquals) { - new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) - } else { + doEvaluation(data) + } + + @transient private lazy val doEvaluation = if 
(TypeUtils.typeWithProperEquals(elementType)) { + (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + (data: Array[AnyRef]) => { var foundNullElement = false var pos = 0 for (i <- 0 until data.length) { @@ -3576,12 +3568,6 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient protected lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - @transient protected lazy val canUseSpecializedHashSet = elementType match { case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false @@ -3679,7 +3665,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike with ComplexTypeMergingExpression { @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] @@ -3896,7 +3882,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => if (array1.numElements() != 0 && array2.numElements() != 0) { val hs = new OpenHashSet[Any] @@ -4136,7 +4122,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => val hs = new OpenHashSet[Any] var notFoundNullElement = true 
http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index f667a64..3e0621d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -683,12 +683,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) value2Var: NamedLambdaVariable), _) = function - private def keyTypeSupportsEquals = keyType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - /** * The function accepts two key arrays and returns a collection of keys with indexes * to value arrays. Indexes are represented as an array of two items. 
This is a small @@ -696,7 +690,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) */ @transient private lazy val getKeysWithValueIndexes: (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { - if (keyTypeSupportsEquals) { + if (TypeUtils.typeWithProperEquals(keyType)) { getKeysWithIndexesFast } else { getKeysWithIndexesBruteForce http://git-wip-us.apache.org/repos/asf/spark/blob/8af61fba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 5214cdc..76218b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ /** - * Helper functions to check for valid data types. + * Functions to help with checking for valid data types and value comparison of various types. */ object TypeUtils { def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = { @@ -73,4 +73,15 @@ object TypeUtils { } x.length - y.length } + + /** + * Returns true if the equals method of the elements of the data type is implemented properly. + * This also means that they can be safely used in collections relying on the equals method, + * as sets or maps. + */ + def typeWithProperEquals(dataType: DataType): Boolean = dataType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org