Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21050#discussion_r195854348 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -2376,112 +2376,297 @@ case class ArrayDistinct(child: Expression) lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + override def nullSafeEval(array: Any): Any = { - val elementType = child.dataType.asInstanceOf[ArrayType].elementType - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType).distinct - new GenericArrayData(data.asInstanceOf[Array[Any]]) + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for(i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j-1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (array) => { val i = ctx.freshName("i") val j = ctx.freshName("j") - val hs = ctx.freshName("hs") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) val foundNullElement = ctx.freshName("foundNullElement") - val distinctArrayLen = ctx.freshName("distinctArrayLen") - val getValue = CodeGenerator.getValue(array, elementType, i) val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if(elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { --- End diff -- We don't need to check this and can do simply `$foundNullElement = true;`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org