GitHub user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21050#discussion_r195854348
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -2376,112 +2376,297 @@ case class ArrayDistinct(child: Expression)
     
       lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
     
    +  @transient private lazy val ordering: Ordering[Any] =
    +    TypeUtils.getInterpretedOrdering(elementType)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    super.checkInputDataTypes() match {
    +      case f: TypeCheckResult.TypeCheckFailure => f
    +      case TypeCheckResult.TypeCheckSuccess =>
    +        TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
    +    }
    +  }
    +
    +  @transient private lazy val elementTypeSupportEquals = elementType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
       override def nullSafeEval(array: Any): Any = {
    -    val elementType = child.dataType.asInstanceOf[ArrayType].elementType
    -    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType).distinct
    -    new GenericArrayData(data.asInstanceOf[Array[Any]])
    +    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    +    if (elementTypeSupportEquals) {
    +      new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
    +    } else {
    +      var foundNullElement = false
    +      var pos = 0
    +      for(i <- 0 until data.length) {
    +        if (data(i) == null) {
    +          if (!foundNullElement) {
    +            foundNullElement = true
    +            pos = pos + 1
    +          }
    +        } else {
    +          var j = 0
    +          var done = false
    +          while (j <= i && !done) {
    +            if (data(j) != null && ordering.equiv(data(j), data(i))) {
    +              done = true
    +            }
    +            j = j + 1
    +          }
    +          if (i == j-1) {
    +            pos = pos + 1
    +          }
    +        }
    +      }
    +      new GenericArrayData(data.slice(0, pos))
    +    }
       }
     
       override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
         nullSafeCodeGen(ctx, ev, (array) => {
           val i = ctx.freshName("i")
           val j = ctx.freshName("j")
    -      val hs = ctx.freshName("hs")
    +      val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
    +      val getValue1 = CodeGenerator.getValue(array, elementType, i)
    +      val getValue2 = CodeGenerator.getValue(array, elementType, j)
           val foundNullElement = ctx.freshName("foundNullElement")
    -      val distinctArrayLen = ctx.freshName("distinctArrayLen")
    -      val getValue = CodeGenerator.getValue(array, elementType, i)
           val openHashSet = classOf[OpenHashSet[_]].getName
    +      val hs = ctx.freshName("hs")
           val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
    +      if(elementTypeSupportEquals) {
    +        s"""
    +           |int $sizeOfDistinctArray = 0;
    +           |boolean $foundNullElement = false;
    +           |$openHashSet $hs = new $openHashSet($classTag);
    +           |for (int $i = 0; $i < $array.numElements(); $i++) {
    +           |  if ($array.isNullAt($i)) {
    +           |     if (!($foundNullElement)) {
    --- End diff --
    
    We don't need to check this; can we simply do `$foundNullElement = true;` instead?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to