Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21061#discussion_r196621059 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -2189,3 +2189,293 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +object ArraySetLike { + def useGenericArrayData(elementSize: Int, length: Int): Boolean = { + // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() + val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val valueRegionInBytes = elementSize.toLong * length + val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 + totalSizeInLongs > Integer.MAX_VALUE / 8 + } + + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with ${length}" + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + + def evalUnionContainsNull( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + if (ordering == null) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new mutable.HashSet[Any] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.get(i, elementType) + if (hs.add(elem)) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) + } else { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check 
elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } + } +} + + +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + protected def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (!cn) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + val hsSize = new OpenHashSet[Int] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getInt(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Int](hsSize.size) + val hs = new OpenHashSet[Int] + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements () ) { + val elem = array.getInt (i) + if (!hs.contains (elem) ) { + resultArray (pos) = elem + hs.add (elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) + } + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + val hsSize = new OpenHashSet[Long] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getLong(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Long](hsSize.size) + val hs = new OpenHashSet[Long] + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = 
array.getLong(i) + if (!hs.contains(elem)) { + resultArray(pos) = elem + hs.add(elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(LongType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) + } + case _ => --- End diff -- You are right. Addressing [this comment](https://github.com/apache/spark/pull/21061#discussion_r194520120) can fix this issue.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org