This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aea9a57  [SPARK-27134][SQL] array_distinct function does not work correctly with columns containing array of array
aea9a57 is described below

commit aea9a574c44768d1d93ee7e8069729383859292c
Author: Dilip Biswal <dbis...@us.ibm.com>
AuthorDate: Sat Mar 16 14:30:42 2019 -0500

    [SPARK-27134][SQL] array_distinct function does not work correctly with columns containing array of array
    
    ## What changes were proposed in this pull request?
    Corrects the de-duplication logic. The previous implementation only counted the
    distinct elements and returned `data.slice(0, pos)`, i.e. the first `pos` elements
    of the input array rather than the de-duplicated elements themselves, which is
    wrong whenever a duplicate precedes a new distinct value. The fix collects the
    distinct elements into a buffer as they are found.
    
    Below is a small repro snippet.
    
    ```
    scala> val df = Seq(Seq(Seq(1, 2), Seq(1, 2), Seq(1, 2), Seq(3, 4), Seq(4, 5))).toDF("array_col")
    df: org.apache.spark.sql.DataFrame = [array_col: array<array<int>>]
    
    scala> val distinctDF = df.select(array_distinct(col("array_col")))
    distinctDF: org.apache.spark.sql.DataFrame = [array_distinct(array_col): array<array<int>>]
    
    scala> df.show(false)
    +----------------------------------------+
    |array_col                               |
    +----------------------------------------+
    |[[1, 2], [1, 2], [1, 2], [3, 4], [4, 5]]|
    +----------------------------------------+
    ```
    Actual (incorrect) result
    ```
    scala> distinctDF.show(false)
    +-------------------------+
    |array_distinct(array_col)|
    +-------------------------+
    |[[1, 2], [1, 2], [1, 2]] |
    +-------------------------+
    ```
    Expected result
    ```
    scala> distinctDF.show(false)
    +-------------------------+
    |array_distinct(array_col)|
    +-------------------------+
    |[[1, 2], [3, 4], [4, 5]] |
    +-------------------------+
    ```
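
    For illustration, below is a minimal standalone sketch of the corrected approach
    (the function name and the `equiv` parameter are illustrative, not Spark
    internals): an element is kept only if no equivalent element has been kept
    before, and at most one null is kept.
    ```scala
    import scala.collection.mutable.ArrayBuffer

    // Keep the first occurrence of each element, comparing with a caller-supplied
    // equivalence function; null values are de-duplicated separately.
    def distinctPreservingOrder[T <: AnyRef](data: Array[T])(equiv: (T, T) => Boolean): Seq[T] = {
      val kept = ArrayBuffer.empty[T]
      var storedNull = false
      for (elem <- data) {
        if (elem == null) {
          if (!storedNull) { kept += elem; storedNull = true }
        } else if (!kept.exists(k => k != null && equiv(k, elem))) {
          kept += elem
        }
      }
      kept.toSeq
    }

    // distinctPreservingOrder(Array(Seq(1, 2), Seq(1, 2), Seq(3, 4)))(_ == _)
    // keeps Seq(1, 2) once and then Seq(3, 4), matching the expected result above.
    ```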
    ## How was this patch tested?
    Added additional test cases for binary arrays and nested arrays in CollectionExpressionsSuite.
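
    As a complementary, hypothetical end-to-end check (not part of this patch, which
    tests at the expression level), an assertion along these lines should now pass in
    spark-shell:
    ```scala
    import org.apache.spark.sql.functions.{array_distinct, col}

    val df = Seq(Seq(Seq(1, 2), Seq(1, 2), Seq(1, 2), Seq(3, 4), Seq(4, 5))).toDF("array_col")
    val result = df.select(array_distinct(col("array_col"))).head.getSeq[Seq[Int]](0)
    assert(result == Seq(Seq(1, 2), Seq(3, 4), Seq(4, 5)))
    ```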
    
    Closes #24073 from dilipbiswal/SPARK-27134.
    
    Authored-by: Dilip Biswal <dbis...@us.ibm.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 34 +++++++++++-----------
 .../expressions/CollectionExpressionsSuite.scala   | 11 +++++++
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6cc5bc8..f7955bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3112,29 +3112,29 @@ case class ArrayDistinct(child: Expression)
     (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
   } else {
     (data: Array[AnyRef]) => {
-      var foundNullElement = false
-      var pos = 0
+      val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef]
+      var alreadyStoredNull = false
       for (i <- 0 until data.length) {
-        if (data(i) == null) {
-          if (!foundNullElement) {
-            foundNullElement = true
-            pos = pos + 1
-          }
-        } else {
+        if (data(i) != null) {
+          var found = false
           var j = 0
-          var done = false
-          while (j <= i && !done) {
-            if (data(j) != null && ordering.equiv(data(j), data(i))) {
-              done = true
-            }
-            j = j + 1
+          while (!found && j < arrayBuffer.size) {
+            val va = arrayBuffer(j)
+            found = (va != null) && ordering.equiv(va, data(i))
+            j += 1
           }
-          if (i == j - 1) {
-            pos = pos + 1
+          if (!found) {
+            arrayBuffer += data(i)
+          }
+        } else {
+          // De-duplicate the null values.
+          if (!alreadyStoredNull) {
+            arrayBuffer += data(i)
+            alreadyStoredNull = true
           }
         }
       }
-      new GenericArrayData(data.slice(0, pos))
+      new GenericArrayData(arrayBuffer)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index c1c0459..603073b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1364,6 +1364,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       ArrayType(DoubleType))
     val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
       ArrayType(FloatType))
+    val a8 =
+      Literal.create(Seq(2, 1, 2, 3, 4, 4, 5).map(_.toString.getBytes), ArrayType(BinaryType))
 
     checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
     checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
@@ -1373,6 +1375,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
     checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
     checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
+    checkEvaluation(new ArrayDistinct(a8), Seq(2, 1, 3, 4, 5).map(_.toString.getBytes))
 
     // complex data types
     val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
@@ -1393,9 +1396,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       ArrayType(ArrayType(IntegerType)))
     val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
       ArrayType(ArrayType(IntegerType)))
+    val c3 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](1, 2), Seq[Int](1, 2),
+      Seq[Int](3, 4), Seq[Int](4, 5)), ArrayType(ArrayType(IntegerType)))
+    val c4 = Literal.create(Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](1, 2),
+      Seq[Int](3, 4), Seq[Int](4, 5), null), ArrayType(ArrayType(IntegerType)))
     checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
     checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
     checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
+    checkEvaluation(ArrayDistinct(c3), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4),
+      Seq[Int](4, 5)))
+    checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4),
+      Seq[Int](4, 5)))
   }
 
   test("Array Union") {

