Repository: spark
Updated Branches:
  refs/heads/master 15747cfd3 -> 9de11d3f9


[SPARK-23912][SQL] add array_distinct

## What changes were proposed in this pull request?

Add `array_distinct`, which removes duplicate values from an array.

## How was this patch tested?

Add unit tests

Author: Huaxin Gao <huax...@us.ibm.com>

Closes #21050 from huaxingao/spark-23912.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9de11d3f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9de11d3f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9de11d3f

Branch: refs/heads/master
Commit: 9de11d3f901bc206a33b9da3e7499bcd43e0142a
Parents: 15747cf
Author: Huaxin Gao <huax...@us.ibm.com>
Authored: Thu Jun 21 12:24:53 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Thu Jun 21 12:24:53 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  14 +
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/collectionOperations.scala      | 279 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            |  45 +++
 .../scala/org/apache/spark/sql/functions.scala  |   7 +
 .../spark/sql/DataFrameFunctionsSuite.scala     |  22 ++
 6 files changed, 368 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e634669..11b179f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1999,6 +1999,20 @@ def array_remove(col, element):
     return Column(sc._jvm.functions.array_remove(_to_java_column(col), 
element))
 
 
+@since(2.4)
+def array_distinct(col):
+    """
+    Collection function: removes duplicate values from the array.
+
+    :param col: name of column or expression
+
+    >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
+    >>> df.select(array_distinct(df.data)).collect()
+    [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3700c63..4b09b9a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -433,6 +433,7 @@ object FunctionRegistry {
     expression[Flatten]("flatten"),
     expression[ArrayRepeat]("array_repeat"),
     expression[ArrayRemove]("array_remove"),
+    expression[ArrayDistinct]("array_distinct"),
     CreateStruct.registryEntry,
 
     // mask functions

http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index d76f301..7c064a1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * Base trait for [[BinaryExpression]]s with two arrays of the same element 
type and implicit
@@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: 
Expression)
 
   override def prettyName: String = "array_remove"
 }
+
+/**
+ * Removes duplicate values from the array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array) - Removes duplicate values from the array.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3, null, 3));
+       [1,2,3,null]
+  """, since = "2.4.0")
+case class ArrayDistinct(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  override def dataType: DataType = child.dataType
+
+  @transient lazy val elementType: DataType = 
dataType.asInstanceOf[ArrayType].elementType
+
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(elementType)
+
+  /**
+   * Runs the inherited input-type check first; only when that succeeds does it
+   * additionally require the array's element type to be accepted by
+   * `TypeUtils.checkForOrderingExpr`.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val inputTypeCheck = super.checkInputDataTypes()
+    inputTypeCheck match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        val caller = s"function $prettyName"
+        TypeUtils.checkForOrderingExpr(elementType, caller)
+      case failure: TypeCheckResult.TypeCheckFailure => failure
+    }
+  }
+
+  @transient private lazy val elementTypeSupportEquals = elementType match {
+    case BinaryType => false
+    case _: AtomicType => true
+    case _ => false
+  }
+
+  /**
+   * Interpreted evaluation: returns a new array holding the first occurrence
+   * of every distinct element of `array`, in input order, with at most one
+   * null element.
+   *
+   * @param array the non-null input value, an [[ArrayData]]
+   * @return a [[GenericArrayData]] with duplicates removed
+   */
+  override def nullSafeEval(array: Any): Any = {
+    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
+    if (elementTypeSupportEquals) {
+      // equals/hashCode of the element values are usable, so the standard
+      // library can do the de-duplication.
+      new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
+    } else {
+      // Elements must be compared through `ordering`. Kept elements are
+      // copied into `kept` as they are found; slicing the *input* array by
+      // the number of distinct elements would return wrong values whenever a
+      // duplicate precedes a later distinct value (e.g. [x, x, y]).
+      val kept = new Array[AnyRef](data.length)
+      var pos = 0
+      var foundNullElement = false
+      var i = 0
+      while (i < data.length) {
+        val elem = data(i)
+        if (elem == null) {
+          // Keep only the first null, at the position where it first occurs.
+          if (!foundNullElement) {
+            foundNullElement = true
+            kept(pos) = null
+            pos = pos + 1
+          }
+        } else {
+          // Only the elements already kept need to be checked.
+          var j = 0
+          var duplicate = false
+          while (j < pos && !duplicate) {
+            val candidate = kept(j)
+            if (candidate != null && ordering.equiv(candidate, elem)) {
+              duplicate = true
+            }
+            j = j + 1
+          }
+          if (!duplicate) {
+            kept(pos) = elem
+            pos = pos + 1
+          }
+        }
+        i = i + 1
+      }
+      new GenericArrayData(kept.slice(0, pos))
+    }
+  }
+
+  /**
+   * Emits the Java code for this expression. The code produced here is a
+   * counting pass: it computes how many elements the result will have
+   * (distinct non-null values plus one slot if any null was seen) and stores
+   * that in `sizeOfDistinctArray`. The code that allocates and fills the
+   * result is appended from `genCodeForResult`.
+   *
+   * When `elementTypeSupportEquals` is true the count comes from an
+   * `OpenHashSet`; otherwise each element is compared with the elements
+   * before it using `ctx.genEqual`.
+   */
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (array) => {
+      val i = ctx.freshName("i")
+      val j = ctx.freshName("j")
+      val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
+      val getValue1 = CodeGenerator.getValue(array, elementType, i)
+      val getValue2 = CodeGenerator.getValue(array, elementType, j)
+      val foundNullElement = ctx.freshName("foundNullElement")
+      val openHashSet = classOf[OpenHashSet[_]].getName
+      val hs = ctx.freshName("hs")
+      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
+      if (elementTypeSupportEquals) {
+        // Hash-set path: nulls are tracked by a flag, never added to the set.
+        s"""
+           |int $sizeOfDistinctArray = 0;
+           |boolean $foundNullElement = false;
+           |$openHashSet $hs = new $openHashSet($classTag);
+           |for (int $i = 0; $i < $array.numElements(); $i ++) {
+           |  if ($array.isNullAt($i)) {
+           |    $foundNullElement = true;
+           |  } else {
+           |    $hs.add($getValue1);
+           |  }
+           |}
+           |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
+           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
+         """.stripMargin
+      } else {
+        // Pairwise path: element i is counted only when the inner loop runs to
+        // completion, i.e. no earlier non-null element equals it.
+        s"""
+           |int $sizeOfDistinctArray = 0;
+           |boolean $foundNullElement = false;
+           |for (int $i = 0; $i < $array.numElements(); $i ++) {
+           |  if ($array.isNullAt($i)) {
+           |     if (!($foundNullElement)) {
+           |       $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
+           |       $foundNullElement = true;
+           |     }
+           |  } else {
+           |    int $j;
+           |    for ($j = 0; $j < $i; $j ++) {
+           |      if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, 
getValue1, getValue2)}) {
+           |        break;
+           |      }
+           |    }
+           |    if ($i == $j) {
+           |     $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
+           |    }
+           |  }
+           |}
+           |
+           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
+         """.stripMargin
+      }
+    })
+  }
+
+  /**
+   * Builds the Java snippet that writes a null into the result array the first
+   * time a null input element is met, advances the write position and raises
+   * the "null already seen" flag.
+   *
+   * @param isPrimitive true when the result is written through `setNullAt`,
+   *                    false when it is a plain Java array slot assignment
+   * @param foundNullElement name of the generated boolean flag
+   * @param distinctArray name of the generated result array variable
+   * @param pos name of the generated write-position variable
+   */
+  private def setNull(
+      isPrimitive: Boolean,
+      foundNullElement: String,
+      distinctArray: String,
+      pos: String): String = {
+    val nullAssignment =
+      if (isPrimitive) {
+        s"$distinctArray.setNullAt($pos)"
+      } else {
+        s"$distinctArray[$pos] = null"
+      }
+
+    s"""
+       |if (!($foundNullElement)) {
+       |  $nullAssignment;
+       |  $pos = $pos + 1;
+       |  $foundNullElement = true;
+       |}
+    """.stripMargin
+  }
+
+  /**
+   * Builds the Java expression that stores a non-null element into the result
+   * array at `pos`: a typed `set<Type>` call for primitive results, a plain
+   * array slot assignment otherwise.
+   */
+  private def setNotNullValue(isPrimitive: Boolean,
+      distinctArray: String,
+      pos: String,
+      getValue1: String,
+      primitiveValueTypeName: String): String = {
+    if (isPrimitive) {
+      s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"
+    } else {
+      s"$distinctArray[$pos] = $getValue1"
+    }
+  }
+
+  /**
+   * Builds the Java snippet for the hash-set path: when the current element is
+   * not yet in the set `hs`, it is added to the set, stored into the result
+   * array and the write position is advanced.
+   */
+  private def setValueForFastEval(
+      isPrimitive: Boolean,
+      hs: String,
+      distinctArray: String,
+      pos: String,
+      getValue1: String,
+      primitiveValueTypeName: String): String = {
+    val storeElement =
+      setNotNullValue(isPrimitive, distinctArray, pos, getValue1, primitiveValueTypeName)
+    s"""
+       |if (!($hs.contains($getValue1))) {
+       |  $hs.add($getValue1);
+       |  $storeElement;
+       |  $pos = $pos + 1;
+       |}
+    """.stripMargin
+  }
+
+  /**
+   * Builds the Java snippet for the pairwise path: the current element `i` is
+   * compared (via `isEqual`) with every earlier non-null input element `j`;
+   * when the scan finishes without a match the element is stored into the
+   * result array and the write position is advanced.
+   */
+  private def setValueForBruteForceEval(
+      isPrimitive: Boolean,
+      i: String,
+      j: String,
+      inputArray: String,
+      distinctArray: String,
+      pos: String,
+      getValue1: String,
+      isEqual: String,
+      primitiveValueTypeName: String): String = {
+    val storeElement =
+      setNotNullValue(isPrimitive, distinctArray, pos, getValue1, primitiveValueTypeName)
+    s"""
+       |int $j;
+       |for ($j = 0; $j < $i; $j ++) {
+       |  if (!$inputArray.isNullAt($j) && $isEqual) {
+       |    break;
+       |  }
+       |}
+       |if ($i == $j) {
+       |  $storeElement;
+       |  $pos = $pos + 1;
+       |}
+    """.stripMargin
+  }
+
+  /**
+   * Emits the Java code of the fill pass: it allocates a result array with
+   * `size` slots, walks `inputArray` once more, writes the first occurrence of
+   * each distinct element (and at most one null) and assigns the result to
+   * `ev.value`.
+   *
+   * Non-primitive element types are collected in an `Object[]` wrapped in a
+   * [[GenericArrayData]]; primitive element types are written into an array
+   * created by `ctx.createUnsafeArray`.
+   *
+   * @param ctx codegen context used for fresh names and helper snippets
+   * @param ev the expression code whose `value` receives the result
+   * @param inputArray name of the generated variable holding the input array
+   * @param size name of the generated variable holding the result length
+   */
+  def genCodeForResult(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      inputArray: String,
+      size: String): String = {
+    val distinctArray = ctx.freshName("distinctArray")
+    val i = ctx.freshName("i")
+    val j = ctx.freshName("j")
+    val pos = ctx.freshName("pos")
+    val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
+    val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
+    val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
+    val foundNullElement = ctx.freshName("foundNullElement")
+    val hs = ctx.freshName("hs")
+    val openHashSet = classOf[OpenHashSet[_]].getName
+    if (!CodeGenerator.isPrimitiveType(elementType)) {
+      val arrayClass = classOf[GenericArrayData].getName
+      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
+      val setNullForNonPrimitive =
+        setNull(false, foundNullElement, distinctArray, pos)
+      if (elementTypeSupportEquals) {
+        // Non-primitive elements, de-duplicated through the hash set.
+        val setValueForFast = setValueForFastEval(false, hs, distinctArray, 
pos, getValue1, "")
+        s"""
+           |int $pos = 0;
+           |Object[] $distinctArray = new Object[$size];
+           |boolean $foundNullElement = false;
+           |$openHashSet $hs = new $openHashSet($classTag);
+           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+           |  if ($inputArray.isNullAt($i)) {
+           |    $setNullForNonPrimitive;
+           |  } else {
+           |    $setValueForFast;
+           |  }
+           |}
+           |${ev.value} = new $arrayClass($distinctArray);
+        """.stripMargin
+      } else {
+        // Non-primitive elements, de-duplicated by pairwise comparison.
+        val setValueForBruteForce = setValueForBruteForceEval(
+          false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
+        s"""
+           |int $pos = 0;
+           |Object[] $distinctArray = new Object[$size];
+           |boolean $foundNullElement = false;
+           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+           |  if ($inputArray.isNullAt($i)) {
+           |    $setNullForNonPrimitive;
+           |  } else {
+           |    $setValueForBruteForce;
+           |  }
+           |}
+           |${ev.value} = new $arrayClass($distinctArray);
+       """.stripMargin
+      }
+    } else {
+      // Primitive elements: only the hash-set path is generated here.
+      // NOTE(review): presumably every primitive element type has
+      // elementTypeSupportEquals == true - confirm.
+      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+      val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, 
pos)
+      val classTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
+      val setValueForFast =
+        setValueForFastEval(true, hs, distinctArray, pos, getValue1, 
primitiveValueTypeName)
+      s"""
+         |${ctx.createUnsafeArray(distinctArray, size, elementType, s" 
$prettyName failed.")}
+         |int $pos = 0;
+         |boolean $foundNullElement = false;
+         |$openHashSet $hs = new $openHashSet($classTag);
+         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+         |  if ($inputArray.isNullAt($i)) {
+         |    $setNullForPrimitive;
+         |  } else {
+         |    $setValueForFast;
+         |  }
+         |}
+         |${ev.value} = $distinctArray;
+      """.stripMargin
+    }
+  }
+
+  override def prettyName: String = "array_distinct"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 85e692b..f377f9c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 
6), Seq[Int](2, 1)))
     checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, 
Seq[Int](2, 1)))
   }
+
+  // Expected results keep the first occurrence of each value in input order,
+  // and at most one null element survives.
+  test("Array Distinct") {
+    val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+    val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), 
ArrayType(StringType))
+    val a3 = Literal.create(Seq("b", null, "a", null, "a", null), 
ArrayType(StringType))
+    val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType))
+    val a5 = Literal.create(Seq(true, false, false, true), 
ArrayType(BooleanType))
+    val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 
0.1234),
+      ArrayType(DoubleType))
+    val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 
1.121f, 0.1234f),
+      ArrayType(FloatType))
+
+    checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
+    checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
+    checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c"))
+    checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a"))
+    checkEvaluation(new ArrayDistinct(a4), Seq(null))
+    checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
+    checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
+    checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
+
+    // complex data types
+    // NOTE(review): in every binary/nested input below, all distinct values
+    // occur before the first duplicate. Consider adding an input such as
+    // Seq(x, x, y), where a duplicate precedes a later distinct value.
+    val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 
2),
+      Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
+    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+      ArrayType(BinaryType))
+    val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, 
Array[Byte](1, 2),
+      null, Array[Byte](5, 6), null), ArrayType(BinaryType))
+
+    checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), 
Array[Byte](1, 2)))
+    checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), 
null))
+    checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), 
null,
+      Array[Byte](1, 2)))
+
+    val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), 
Seq[Int](1, 2),
+      Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType)))
+    val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+      ArrayType(ArrayType(IntegerType)))
+    val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, 
Seq[Int](2, 1), null),
+      ArrayType(ArrayType(IntegerType)))
+    checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), 
Seq[Int](3, 4)))
+    checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), 
Seq[Int](2, 1)))
+    checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8551058..965dbb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3190,6 +3190,13 @@ object functions {
   }
 
   /**
+   * Removes duplicate values from the array.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
+
+  /**
    * Creates a new row for each element in the given array or map column.
    *
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/9de11d3f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 4e5c1c5..3dc696b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1216,6 +1216,28 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext {
     assert(e.message.contains("argument 1 requires array type, however, '`_1`' 
is of string type"))
   }
 
+  // Covers an int array, a string array (including empty strings), empty
+  // arrays and null arrays, through both the Column API and a SQL expression.
+  test("array_distinct functions") {
+    val df = Seq(
+      (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")),
+      (Array.empty[Int], Array.empty[String]),
+      (null, null)
+    ).toDF("a", "b")
+    // Column API
+    checkAnswer(
+      df.select(array_distinct($"a"), array_distinct($"b")),
+      Seq(
+        Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")),
+        Row(Seq.empty[Int], Seq.empty[String]),
+        Row(null, null))
+    )
+    // Same expectations through the SQL function name
+    checkAnswer(
+      df.selectExpr("array_distinct(a)", "array_distinct(b)"),
+      Seq(
+        Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")),
+        Row(Seq.empty[Int], Seq.empty[String]),
+        Row(null, null))
+    )
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), 
(false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to