This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 1222ce0  [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements
1222ce0 is described below

commit 1222ce064f97ed9ad34e2fca4d270762592a1854
Author: Pablo Langa <soy...@gmail.com>
AuthorDate: Fri May 1 22:09:04 2020 +0900

    [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements

    ### What changes were proposed in this pull request?

    The collect_set() aggregate function should produce a set of distinct elements. When the column argument's type is BinaryType, this is not the case.

    Example:
    ```scala
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window

    case class R(id: String, value: String, bytes: Array[Byte])
    def makeR(id: String, value: String) = R(id, value, value.getBytes)
    val df = Seq(makeR("a", "dog"), makeR("a", "cat"), makeR("a", "cat"), makeR("b", "fish")).toDF()

    // In the example below "bytesSet" erroneously has duplicates but "stringSet" does not (as expected).
    df.agg(collect_set('value) as "stringSet", collect_set('bytes) as "bytesSet").show(truncate = false)

    // The same problem shows up when using window functions.
    val win = Window.partitionBy('id).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    val result = df.select(
        collect_set('value).over(win) as "stringSet",
        collect_set('bytes).over(win) as "bytesSet"
      )
      .select('stringSet, 'bytesSet, size('stringSet) as "stringSetSize", size('bytesSet) as "bytesSetSize")
      .show()
    ```

    We use a HashSet buffer to accumulate the results. The problem is that array equality in Scala does not behave as one might expect: arrays are plain Java arrays, so `==` compares references rather than contents, e.g. `Array(1, 2, 3) == Array(1, 2, 3)` evaluates to `false`. As a result, duplicate byte arrays are not removed from the HashSet.

    The proposed solution is to store BinaryType elements in the HashSet buffer as a Catalyst array value, which has value-based equality, so duplicates are removed as elements are added, and then convert the elements back to the original byte-array type in `eval()`. This conversion is only applied when the element type is BinaryType.

    ### Why are the changes needed?
    Fix the bug explained above.

    ### Does this PR introduce any user-facing change?
    Yes. Now `collect_set()` correctly deduplicates arrays of bytes.

    ### How was this patch tested?
    Unit testing.

    Closes #28351 from planga82/feature/SPARK-31500_COLLECT_SET_bug.
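    The root cause is easy to reproduce outside Spark. The following is a minimal, self-contained Scala sketch and is not part of this patch; the object name and the `Seq[Byte]` conversion are illustrative stand-ins, whereas the committed fix converts BinaryType values to `UnsafeArrayData` inside `CollectSet`, as shown in the diff below.

    ```scala
    import scala.collection.mutable

    object BinaryDedupSketch {
      def main(args: Array[String]): Unit = {
        // Two distinct array instances with identical contents.
        val a1: Array[Byte] = "cat".getBytes("UTF-8")
        val a2: Array[Byte] = "cat".getBytes("UTF-8")

        // Arrays are plain Java arrays: `==` is reference equality and hashCode is
        // identity-based, so a HashSet cannot recognize the two as duplicates.
        println(a1 == a2)             // false
        println(a1.sameElements(a2))  // true: content equality must be requested explicitly

        val rawSet = mutable.HashSet[Any](a1, a2)
        println(rawSet.size)          // 2 -- the "duplicate" survives

        // Converting each array to a value with structural equality (Seq[Byte] here)
        // lets the HashSet deduplicate; the byte arrays can be rebuilt afterwards.
        val dedupSet = mutable.HashSet[Seq[Byte]](a1.toSeq, a2.toSeq)
        println(dedupSet.size)        // 1

        val restored: Array[Array[Byte]] = dedupSet.iterator.map(_.toArray).toArray
        println(restored.length)      // 1
      }
    }
    ```

    The patch applies the same idea at the Catalyst level: `convertToBufferElement` stores BinaryType values as array data with value-based equality, and `eval()` converts the deduplicated elements back to byte arrays.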
    Authored-by: Pablo Langa <soy...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
    (cherry picked from commit 4fecc20f6ecdfe642890cf0a368a85558c40a47c)
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../catalyst/expressions/aggregate/collect.scala   | 45 +++++++++++++++++++---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 16 ++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index be972f0..8dc3171 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
@@ -46,13 +47,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   // actual order of input rows.
   override lazy val deterministic: Boolean = false
 
+  protected def convertToBufferElement(value: Any): Any
+
   override def update(buffer: T, input: InternalRow): T = {
     val value = child.eval(input)
 
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     if (value != null) {
-      buffer += InternalRow.copyValue(value)
+      buffer += convertToBufferElement(value)
     }
     buffer
   }
@@ -61,12 +64,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
     buffer ++= other
   }
 
-  override def eval(buffer: T): Any = {
-    new GenericArrayData(buffer.toArray)
-  }
+  protected val bufferElementType: DataType
 
   private lazy val projection = UnsafeProjection.create(
-    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
   private lazy val row = new UnsafeRow(1)
 
   override def serialize(obj: T): Array[Byte] = {
@@ -77,7 +78,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   override def deserialize(bytes: Array[Byte]): T = {
     val buffer = createAggregationBuffer()
     row.pointTo(bytes, bytes.length)
-    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
     buffer
   }
 }
@@ -94,6 +95,10 @@ case class CollectList(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType
+
+  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
@@ -103,6 +108,10 @@ case class CollectList(
   override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
   override def prettyName: String = "collect_list"
+
+  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
 }
 
 /**
@@ -117,6 +126,30 @@ case class CollectSet(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType match {
+    case BinaryType => ArrayType(ByteType)
+    case other => other
+  }
+
+  override def convertToBufferElement(value: Any): Any = child.dataType match {
+    /*
+     * collect_set() of BinaryType should not return duplicate elements,
+     * Java byte arrays use referential equality and identity hash codes
+     * so we need to use a different catalyst value for arrays
+     */
+    case BinaryType => UnsafeArrayData.fromPrimitiveArray(value.asInstanceOf[Array[Byte]])
+    case _ => InternalRow.copyValue(value)
+  }
+
+  override def eval(buffer: mutable.HashSet[Any]): Any = {
+    val array = child.dataType match {
+      case BinaryType =>
+        buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray).toArray
+      case _ => buffer.toArray
+    }
+    new GenericArrayData(array)
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c4..bb7c68a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -509,6 +509,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
+    val bytesTest1 = "test1".getBytes
+    val bytesTest2 = "test2".getBytes
+    val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
+    checkAnswer(df.select(size(collect_set($"a"))), Row(2) :: Nil)
+
+    val a = "aa".getBytes
+    val b = "bb".getBytes
+    val c = "cc".getBytes
+    val d = "dd".getBytes
+    val df1 = Seq((a, b), (a, b), (c, d))
+      .toDF("x", "y")
+      .select(struct($"x", $"y").as("a"))
+    checkAnswer(df1.select(size(collect_set($"a"))), Row(2) :: Nil)
+  }
+
   test("collect_set functions cannot have maps") {
     val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
       .toDF("a", "x", "y")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org