Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21050#discussion_r195856243
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -2355,3 +2356,319 @@ case class ArrayRemove(left: Expression, right: Expression)
     
       override def prettyName: String = "array_remove"
     }
    +
    +/**
    + * Removes duplicate values from the array.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(array) - Removes duplicate values from the array.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3, null, 3));
    +       [1,2,3,null]
    +  """, since = "2.4.0")
    +case class ArrayDistinct(child: Expression)
    +  extends UnaryExpression with ExpectsInputTypes {
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
    +
    +  override def dataType: DataType = child.dataType
    +
    +  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
    +
    +  @transient private lazy val ordering: Ordering[Any] =
    +    TypeUtils.getInterpretedOrdering(elementType)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    super.checkInputDataTypes() match {
    +      case f: TypeCheckResult.TypeCheckFailure => f
    +      case TypeCheckResult.TypeCheckSuccess =>
    +        TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
    +    }
    +  }
    +
    +  @transient private lazy val elementTypeSupportEquals = elementType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
    +  override def nullSafeEval(array: Any): Any = {
    +    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    +    if (elementTypeSupportEquals) {
    +      new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
    +    } else {
    +      var foundNullElement = false
    +      var pos = 0
    +      for(i <- 0 until data.length) {
    +        if (data(i) == null) {
    +          if (!foundNullElement) {
    +            foundNullElement = true
    +            pos = pos + 1
    +          }
    +        } else {
    +          var j = 0
    +          var done = false
    +          while (j <= i && !done) {
    +            if (data(j) != null && ordering.equiv(data(j), data(i))) {
    +              done = true
    +            }
    +            j = j + 1
    +          }
    +          if (i == j-1) {
    +            pos = pos + 1
    +          }
    +        }
    +      }
    +      new GenericArrayData(data.slice(0, pos))
    +    }
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    nullSafeCodeGen(ctx, ev, (array) => {
    +      val i = ctx.freshName("i")
    +      val j = ctx.freshName("j")
    +      val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
    +      val getValue1 = CodeGenerator.getValue(array, elementType, i)
    +      val getValue2 = CodeGenerator.getValue(array, elementType, j)
    +      val foundNullElement = ctx.freshName("foundNullElement")
    +      val openHashSet = classOf[OpenHashSet[_]].getName
    +      val hs = ctx.freshName("hs")
    +      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
    +      if(elementTypeSupportEquals) {
    +        s"""
    +           |int $sizeOfDistinctArray = 0;
    +           |boolean $foundNullElement = false;
    +           |$openHashSet $hs = new $openHashSet($classTag);
    +           |for (int $i = 0; $i < $array.numElements(); $i++) {
    +           |  if ($array.isNullAt($i)) {
    +           |     if (!($foundNullElement)) {
    +           |       $foundNullElement = true;
    +           |     }
    +           |  }
    +           |  else {
    +           |    if (!($hs.contains($getValue1))) {
    +           |      $hs.add($getValue1);
    +           |    }
    +           |  }
    +           |}
    +           |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
    +           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
    +         """.stripMargin
    +      }
    +      else {
    +        s"""
    +           |int $sizeOfDistinctArray = 0;
    +           |boolean $foundNullElement = false;
    +           |for (int $i = 0; $i < $array.numElements(); $i ++) {
    +           |  if ($array.isNullAt($i)) {
    +           |     if (!($foundNullElement)) {
    +           |       $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
    +           |       $foundNullElement = true;
    +           |     }
    +           |  }
    +           |  else {
    +           |    int $j;
    +           |    for ($j = 0; $j < $i; $j++) {
    +           |      if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)})
    +           |        break;
    +           |    }
    +           |    if ($i == $j) {
    +           |     $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
    +           |    }
    +           |  }
    +           |}
    +           |
    +           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
    +         """.stripMargin
    +      }
    +    })
    +  }
    +
    +  private def setNull(
    +      isPrimitive: Boolean,
    +      foundNullElement: String,
    +      distinctArray: String,
    +      pos: String): String = {
    +    val setNullValue =
    +      if (!isPrimitive) {
    +        s"""
    +           |$distinctArray[$pos] = null;
    +        """.
    +          stripMargin
    +      } else {
    +        s"""
    +           |$distinctArray.setNullAt($pos);
    +        """.
    +          stripMargin
    +      }
    +
    +    s"""
    +       |if (!($foundNullElement)) {
    +       |  $setNullValue;
    +       |  $pos = $pos + 1;
    +       |  $foundNullElement = true;
    +       |}
    +    """.stripMargin
    +  }
    +
    +  private def setNotNullValue(isPrimitive: Boolean,
    +      distinctArray: String,
    +      pos: String,
    +      getValue1: String,
    +      primitiveValueTypeName: String): String = {
    +    if (!isPrimitive) {
    +      s"""
    +         |$distinctArray[$pos] = $getValue1;
    +      """.stripMargin
    +    } else {
    +      s"""
    +         |$distinctArray.set$primitiveValueTypeName($pos, $getValue1);
    +      """.stripMargin
    +    }
    +  }
    +
    +  private def setValueForFastEval(
    +      isPrimitive: Boolean,
    +      hs: String,
    +      distinctArray: String,
    +      pos: String,
    +      getValue1: String,
    +      primitiveValueTypeName: String): String = {
    +    val setValue = setNotNullValue(isPrimitive,
    +      distinctArray, pos, getValue1, primitiveValueTypeName)
    +    s"""
    +       |if (!($hs.contains($getValue1))) {
    +       |  $hs.add($getValue1);
    +       |  $setValue;
    +       |  $pos = $pos + 1;
    +       |}
    +    """.stripMargin
    +  }
    +
    +  private def setValueForbruteForceEval(
    +      isPrimitive: Boolean,
    +      i: String,
    +      j: String,
    +      inputArray: String,
    +      distinctArray: String,
    +      pos: String,
    +      getValue1: String,
    +      isEqual: String,
    +      primitiveValueTypeName: String): String = {
    +    val setValue = setNotNullValue(isPrimitive,
    +      distinctArray, pos, getValue1, primitiveValueTypeName)
    +    s"""
    +       |int $j;
    +       |for ($j = 0; $j < $i; $j ++) {
    +       |  if (!$inputArray.isNullAt($j) && $isEqual)
    +       |    break;
    +       |  }
    +       |  if ($i == $j) {
    +       |    $setValue;
    +       |    $pos = $pos + 1;
    +       |  }
    +    """.stripMargin
    +  }
    +
    +  def genCodeForResult(
    +      ctx: CodegenContext,
    +      ev: ExprCode,
    +      inputArray: String,
    +      size: String): String = {
    +    val distinctArray = ctx.freshName("distinctArray")
    +    val i = ctx.freshName("i")
    +    val j = ctx.freshName("j")
    +    val pos = ctx.freshName("pos")
    +    val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
    +    val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
    +    val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
    +    val foundNullElement = ctx.freshName("foundNullElement")
    +    val hs = ctx.freshName("hs")
    +    val openHashSet = classOf[OpenHashSet[_]].getName
    +    if (!CodeGenerator.isPrimitiveType(elementType)) {
    +      val arrayClass = classOf[GenericArrayData].getName
    +      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
    +      val setNullForNonPrimitive =
    +        setNull(false, foundNullElement, distinctArray, pos)
    +      if (elementTypeSupportEquals) {
    +        val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "")
    +        s"""
    +           |int $pos = 0;
    +           |Object[] $distinctArray = new Object[$size];
    +           |boolean $foundNullElement = false;
    +           |$openHashSet $hs = new $openHashSet($classTag);
    +           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
    +           |  if ($inputArray.isNullAt($i)) {
    +           |    $setNullForNonPrimitive;
    +           |  }
    +           |  else {
    +           |    $setValueForFast;
    +           |  }
    +           |}
    +           |${ev.value} = new $arrayClass($distinctArray);
    +        """.stripMargin
    +      }
    +      else {
    +        val setValueForbruteForce = setValueForbruteForceEval(false, i, j,
    +                      inputArray, distinctArray, pos, getValue1: String, isEqual, "")
    --- End diff --
    
    nit: please fix the indentation of the continuation line above.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to