Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21061#discussion_r194527119 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -2189,3 +2189,293 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +object ArraySetLike { + def useGenericArrayData(elementSize: Int, length: Int): Boolean = { + // Use the same calculation in UnsafeArrayData.fromPrimitiveArray() + val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val valueRegionInBytes = elementSize.toLong * length + val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8 + totalSizeInLongs > Integer.MAX_VALUE / 8 + } + + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with ${length}" + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + + def evalUnionContainsNull( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + if (ordering == null) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new mutable.HashSet[Any] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.get(i, elementType) + if (hs.add(elem)) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) + } else { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check 
elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } + } +} + + +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + protected def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (!cn) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + val hsSize = new OpenHashSet[Int] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getInt(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Int](hsSize.size) + val hs = new OpenHashSet[Int] + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements () ) { + val elem = array.getInt (i) + if (!hs.contains (elem) ) { + resultArray (pos) = elem + hs.add (elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) + } + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + val hsSize = new OpenHashSet[Long] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(hsSize.size) + } + hsSize.add(array.getLong(i)) + i += 1 + } + }) + // store elements into array + val resultArray = new Array[Long](hsSize.size) + val hs = new OpenHashSet[Long] + var pos = 0 + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = 
array.getLong(i) + if (!hs.contains(elem)) { + resultArray(pos) = elem + hs.add(elem) + pos += 1 + } + i += 1 + } + }) + if (ArraySetLike.useGenericArrayData(LongType.defaultSize, resultArray.length)) { + new GenericArrayData(resultArray) + } else { + UnsafeArrayData.fromPrimitiveArray(resultArray) + } + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + Seq(array1, array2).foreach(array => { + var i = 0 + while (i < array.numElements()) { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + i += 1 + } + }) + new GenericArrayData(arrayBuffer) + } + } else { + ArraySetLike.evalUnionContainsNull(array1, array2, elementType, + if (elementTypeSupportEquals) null else ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val genericArrayData = classOf[GenericArrayData].getName + val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = if (!cn) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val et = ctx.addReferenceObj("elementType", elementType) + ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", + s"get($i, $et)", s"update($pos, $value)", 
"Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "") + } + + val hs = ctx.freshName("hs") + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (classTag != "") { + val openHashSet = classOf[OpenHashSet[_]].getName --- End diff -- I don't think this is working for some element type with `containsNull = true` properly. E.g., `ArrayUnion(a04, a03)` or `ArrayUnion(a14, a13)`, `aXX` of which are from `"Array Union"` in the `CollectionExpressionsSuite`, should fail with a wrong result.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org