Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21069#discussion_r191956713 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -1882,3 +1882,117 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val 
numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + s""" + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if (!($isEqual)) { --- End diff -- Don't we need to check null?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org