Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21061#discussion_r182663898

--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -505,3 +506,150 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
 
   override def prettyName: String = "array_max"
 }
+
+abstract class ArraySetUtils extends BinaryExpression with ExpectsInputTypes {
+  val kindUnion = 1
+  def typeId: Int
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val r = super.checkInputDataTypes()
+    if ((r == TypeCheckResult.TypeCheckSuccess) &&
+      (left.dataType.asInstanceOf[ArrayType].elementType !=
+        right.dataType.asInstanceOf[ArrayType].elementType)) {
+      TypeCheckResult.TypeCheckFailure("Element type in both arrays must be the same")
+    } else {
+      r
+    }
+  }
+
+  override def dataType: DataType = left.dataType
+
+  private def elementType = dataType.asInstanceOf[ArrayType].elementType
+  private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull
+  private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val ary1 = input1.asInstanceOf[ArrayData]
+    val ary2 = input2.asInstanceOf[ArrayData]
+
+    if (!cn1 && !cn2) {
+      elementType match {
+        case IntegerType =>
+          // avoid boxing of primitive int array elements
+          val hs = new OpenHashSet[Int]
+          var i = 0
+          while (i < ary1.numElements()) {
+            hs.add(ary1.getInt(i))
+            i += 1
+          }
+          i = 0
+          while (i < ary2.numElements()) {
+            hs.add(ary2.getInt(i))
+            i += 1
+          }
+          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+        case LongType =>
+          // avoid boxing of primitive long array elements
+          val hs = new OpenHashSet[Long]
+          var i = 0
+          while (i < ary1.numElements()) {
+            hs.add(ary1.getLong(i))
+            i += 1
+          }
+          i = 0
+          while (i < ary2.numElements()) {
+            hs.add(ary2.getLong(i))
+            i += 1
+          }
+          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+        case _ =>
+          val hs = new OpenHashSet[Any]
+          var i = 0
+          while (i < ary1.numElements()) {
+            hs.add(ary1.get(i, elementType))
+            i += 1
+          }
+          i = 0
+          while (i < ary2.numElements()) {
+            hs.add(ary2.get(i, elementType))
+            i += 1
+          }
+          new GenericArrayData(hs.iterator.toArray)
+      }
+    } else {
+      ArraySetUtils.arrayUnion(ary1, ary2, elementType)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val hs = ctx.freshName("hs")
+    val i = ctx.freshName("i")
+    val ArraySetUtils = "org.apache.spark.sql.catalyst.expressions.ArraySetUtils"
+    val genericArrayData = classOf[GenericArrayData].getName
+    val unsafeArrayData = classOf[UnsafeArrayData].getName
+    val openHashSet = classOf[OpenHashSet[_]].getName
+    val et = s"org.apache.spark.sql.types.DataTypes.$elementType"
+    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && !cn2) {
+      val ptName = CodeGenerator.primitiveTypeName(elementType)
+      elementType match {
+        case ByteType | ShortType | IntegerType =>
+          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
+            s"$unsafeArrayData.fromPrimitiveArray", CodeGenerator.javaType(elementType))
+        case LongType =>
+          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", s"get$ptName($i)",
+            s"$unsafeArrayData.fromPrimitiveArray", "long")
+        case _ =>
+          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, $et)",
+            s"new $genericArrayData", "Object")
+      }
+    } else {
+      ("", "", "", "", "")
+    }
+
+    nullSafeCodeGen(ctx, ev, (ary1, ary2) => {
+      if (classTag != "") {
+        s"""
+           |$openHashSet $hs = new $openHashSet$postFix($classTag);
+           |for (int $i = 0; $i < $ary1.numElements(); $i++) {
+           |  $hs.add$postFix($ary1.$getter);
+           |}
+           |for (int $i = 0; $i < $ary2.numElements(); $i++) {
+           |  $hs.add$postFix($ary2.$getter);
+           |}
+           |${ev.value} = $arrayBuilder(($castType[]) $hs.iterator().toArray($classTag));
--- End diff --

I guess we shouldn't use `iterator()` to avoid box/unbox. `Iterator` is not specialized.