This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 1222ce0  [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements
1222ce0 is described below

commit 1222ce064f97ed9ad34e2fca4d270762592a1854
Author: Pablo Langa <soy...@gmail.com>
AuthorDate: Fri May 1 22:09:04 2020 +0900

    [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements
    
    ### What changes were proposed in this pull request?
    
    The collect_set() aggregate function should produce a set of distinct elements. When the column argument's type is BinaryType, this is not the case.
    
    Example:
    ```scala
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window
    
    case class R(id: String, value: String, bytes: Array[Byte])
    def makeR(id: String, value: String) = R(id, value, value.getBytes)
    val df = Seq(makeR("a", "dog"), makeR("a", "cat"), makeR("a", "cat"), makeR("b", "fish")).toDF()
    // In the example below "bytesSet" erroneously has duplicates but "stringSet" does not (as expected).
    df.agg(collect_set('value) as "stringSet", collect_set('bytes) as "bytesSet").show(truncate=false)
    // The same problem is displayed when using window functions.
    val win = Window.partitionBy('id).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    val result = df.select(
      collect_set('value).over(win) as "stringSet",
      collect_set('bytes).over(win) as "bytesSet"
    )
    .select('stringSet, 'bytesSet, size('stringSet) as "stringSetSize", size('bytesSet) as "bytesSetSize")
    .show()
    ```
    
    We use a HashSet buffer to accumulate the results. The problem is that array equality in Scala does not behave as expected: arrays are plain Java arrays, and `==` compares references rather than contents, so `Array(1, 2, 3) == Array(1, 2, 3)` evaluates to false. As a result, duplicate byte arrays are not removed from the HashSet.
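
    A minimal standalone Scala sketch of the issue (no Spark involved): `mutable.HashSet` relies on `equals`/`hashCode`, and Java byte arrays only provide reference equality and identity hash codes, so arrays with equal contents do not collapse into one element.
    ```scala
    import scala.collection.mutable

    val a1 = Array[Byte](1, 2, 3)
    val a2 = Array[Byte](1, 2, 3)

    a1 == a2              // false: arrays compare by reference
    a1.sameElements(a2)   // true: the contents are equal

    val set = mutable.HashSet[Any](a1, a2)
    set.size              // 2 -- both arrays are kept, the duplicate is not removed

    // Wrapping in a content-comparing type (e.g. Seq[Byte]) deduplicates as expected
    val wrapped = mutable.HashSet[Any](a1.toSeq, a2.toSeq)
    wrapped.size          // 1
    ```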
    
    The proposed solution is to store BinaryType values in the HashSet buffer under a different Catalyst representation (an array of bytes) whose equality is content-based, so that duplicates are removed as elements are accumulated, and to convert the elements back to the original type when the final result is built.
    This conversion is only applied when the child type is BinaryType.
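
    A simplified sketch of that approach in plain Scala, using hypothetical helper names (the actual patch converts the bytes to Catalyst `UnsafeArrayData`, as shown in the diff below): the byte arrays are wrapped in a content-comparing representation before they enter the set, and unwrapped when the result is produced.
    ```scala
    import scala.collection.mutable

    // Hypothetical stand-ins for the buffer-element conversions done by the patch
    def toBufferElement(bytes: Array[Byte]): Seq[Byte] = bytes.toSeq      // content-based equality
    def fromBufferElement(elem: Seq[Byte]): Array[Byte] = elem.toArray    // back to the BinaryType form

    val input = Seq("cat".getBytes, "cat".getBytes, "dog".getBytes)

    val buffer = mutable.HashSet.empty[Seq[Byte]]
    input.foreach(bytes => buffer += toBufferElement(bytes))   // the duplicate "cat" collapses here

    val result: Array[Array[Byte]] = buffer.iterator.map(fromBufferElement).toArray
    result.length   // 2
    ```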
    
    ### Why are the changes needed?
    Fixes the bug explained above.
    
    ### Does this PR introduce any user-facing change?
    Yes. Now `collect_set()` correctly deduplicates arrays of bytes (BinaryType values).
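
    For example, with the `df` defined in the snippet above, both columns now report the same number of distinct elements (an illustrative check; the expected sizes follow from the example data):
    ```scala
    df.agg(size(collect_set('value)) as "stringSetSize", size(collect_set('bytes)) as "bytesSetSize").show()
    // Before the fix: stringSetSize = 3, bytesSetSize = 4 (the duplicate "cat".getBytes is kept)
    // After the fix:  stringSetSize = 3, bytesSetSize = 3
    ```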
    
    ### How was this patch tested?
    Unit testing
    
    Closes #28351 from planga82/feature/SPARK-31500_COLLECT_SET_bug.
    
    Authored-by: Pablo Langa <soy...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
    (cherry picked from commit 4fecc20f6ecdfe642890cf0a368a85558c40a47c)
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../catalyst/expressions/aggregate/collect.scala   | 45 +++++++++++++++++++---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 16 ++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index be972f0..8dc3171 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
@@ -46,13 +47,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   // actual order of input rows.
   override lazy val deterministic: Boolean = false
 
+  protected def convertToBufferElement(value: Any): Any
+
   override def update(buffer: T, input: InternalRow): T = {
     val value = child.eval(input)
 
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     if (value != null) {
-      buffer += InternalRow.copyValue(value)
+      buffer += convertToBufferElement(value)
     }
     buffer
   }
@@ -61,12 +64,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
     buffer ++= other
   }
 
-  override def eval(buffer: T): Any = {
-    new GenericArrayData(buffer.toArray)
-  }
+  protected val bufferElementType: DataType
 
   private lazy val projection = UnsafeProjection.create(
-    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
   private lazy val row = new UnsafeRow(1)
 
   override def serialize(obj: T): Array[Byte] = {
@@ -77,7 +78,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   override def deserialize(bytes: Array[Byte]): T = {
     val buffer = createAggregationBuffer()
     row.pointTo(bytes, bytes.length)
-    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
     buffer
   }
 }
@@ -94,6 +95,10 @@ case class CollectList(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType
+
+  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
@@ -103,6 +108,10 @@ case class CollectList(
   override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
   override def prettyName: String = "collect_list"
+
+  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
 }
 
 /**
@@ -117,6 +126,30 @@ case class CollectSet(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType match {
+    case BinaryType => ArrayType(ByteType)
+    case other => other
+  }
+
+  override def convertToBufferElement(value: Any): Any = child.dataType match {
+    /*
+     * collect_set() of BinaryType should not return duplicate elements,
+     * Java byte arrays use referential equality and identity hash codes
+     * so we need to use a different catalyst value for arrays
+     */
+    case BinaryType => UnsafeArrayData.fromPrimitiveArray(value.asInstanceOf[Array[Byte]])
+    case _ => InternalRow.copyValue(value)
+  }
+
+  override def eval(buffer: mutable.HashSet[Any]): Any = {
+    val array = child.dataType match {
+      case BinaryType =>
+        buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray).toArray
+      case _ => buffer.toArray
+    }
+    new GenericArrayData(array)
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c4..bb7c68a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -509,6 +509,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
+    val bytesTest1 = "test1".getBytes
+    val bytesTest2 = "test2".getBytes
+    val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
+    checkAnswer(df.select(size(collect_set($"a"))), Row(2) :: Nil)
+
+    val a = "aa".getBytes
+    val b = "bb".getBytes
+    val c = "cc".getBytes
+    val d = "dd".getBytes
+    val df1 = Seq((a, b), (a, b), (c, d))
+      .toDF("x", "y")
+      .select(struct($"x", $"y").as("a"))
+    checkAnswer(df1.select(size(collect_set($"a"))), Row(2) :: Nil)
+  }
+
   test("collect_set functions cannot have maps") {
     val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
       .toDF("a", "x", "y")

