Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21061#discussion_r197228219 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -2355,3 +2355,347 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +object ArraySetLike { + def useGenericArrayData(elementSize: Int, length: Int): Boolean = { + // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() + val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val valueRegionInBytes = elementSize.toLong * length + val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 + totalSizeInLongs > Integer.MAX_VALUE / 8 + } + + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + resultArray.setInt(pos, elem) + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + resultArray.setLong(pos, elem) + hsLong.add(elem) + true + } else { + false + } + } + + def evalPrimitiveType( + array1: ArrayData, + array2: ArrayData, + size: Int, + resultArray: ArrayData, + isLongType: Boolean): ArrayData = { + // store elements into resultArray + var foundNullElement = false + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + resultArray.setNullAt(pos) + pos += 1 + foundNullElement = true + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + }) + resultArray + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + val hsSize = new OpenHashSet[Int] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getInt(i)) + i += 1 + } + }) + if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) + } else { + hsInt = new OpenHashSet[Int] + val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Int](hsSize.size)) + evalPrimitiveType(array1, array2, hsSize.size, resultArray, false) + } + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + val hsSize = new OpenHashSet[Long] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getLong(i)) + i += 1 + } + }) + if (UnsafeArrayData.useGenericArrayData(IntegerType.defaultSize, hsSize.size)) { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) + } else { + hsLong = new OpenHashSet[Long] + val resultArray = UnsafeArrayData.fromPrimitiveArray(new Array[Long](hsSize.size)) + evalPrimitiveType(array1, array2, hsSize.size, resultArray, true) + } + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) + } + } else { + ArrayUnion.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", + s"get($i, $et)", s"update($pos, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (classTag != "") { + // Here, we ensure elementTypeSupportEquals && !array1.containsNull && !array2.containsNull --- End diff -- good catch, thanks
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org