Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21061#discussion_r197235630 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -2355,3 +2355,347 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +object ArraySetLike { + def useGenericArrayData(elementSize: Int, length: Int): Boolean = { + // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() + val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val valueRegionInBytes = elementSize.toLong * length + val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 + totalSizeInLongs > Integer.MAX_VALUE / 8 + } + + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + resultArray.setInt(pos, elem) + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + resultArray.setLong(pos, elem) + hsLong.add(elem) + true + } else { + false + } + } + + def evalPrimitiveType( + array1: ArrayData, + array2: ArrayData, + size: Int, + resultArray: ArrayData, + isLongType: Boolean): ArrayData = { + // store elements into resultArray + var foundNullElement = false + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + resultArray.setNullAt(pos) + pos += 1 + foundNullElement = true + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + }) + resultArray + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + val hsSize = new OpenHashSet[Int] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + 
} + hsSize.add(array.getInt(i)) + i += 1 + } + }) + if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) --- End diff -- Great catch!!
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org