Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21061#discussion_r196620839
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -2189,3 +2189,293 @@ case class ArrayRemove(left: Expression, right: 
Expression)
     
       override def prettyName: String = "array_remove"
     }
    +
    +object ArraySetLike {
    +  /**
    +   * Decides whether a primitive result array must be built as a GenericArrayData instead of
    +   * via `UnsafeArrayData.fromPrimitiveArray()`.
    +   *
    +   * Uses the same size calculation as `UnsafeArrayData.fromPrimitiveArray()`.
    +   *
    +   * @param elementSize size in bytes of one element
    +   * @param length      number of elements
    +   * @return true when the total size, in 8-byte words, exceeds `Integer.MAX_VALUE / 8`
    +   */
    +  def useGenericArrayData(elementSize: Int, length: Int): Boolean = {
    +    val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length)
    +    // Widen to Long before multiplying so that large lengths cannot overflow Int.
    +    val valueRegionInBytes = elementSize.toLong * length
    +    val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8
    +    totalSizeInLongs > Integer.MAX_VALUE / 8
    +  }
    +
    +  /**
    +   * Throws a RuntimeException reporting that a union producing `length` elements exceeds
    +   * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`. Never returns normally, hence `Nothing`.
    +   *
    +   * @param length number of elements the union result would have
    +   */
    +  def throwUnionLengthOverflowException(length: Int): Nothing = {
    +    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
    +      s"elements due to exceeding the array size limit " +
    +      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
    +  }
    +
    +  /**
    +   * Interpreted union of two arrays whose elements may be null.
    +   *
    +   * Elements are emitted in first-occurrence order (all of array1, then array2) with
    +   * duplicates removed.
    +   *
    +   * @param array1      left input array
    +   * @param array2      right input array
    +   * @param elementType element type used to read elements from both arrays
    +   * @param ordering    when null, duplicates are detected with a hash set (element equality);
    +   *                    otherwise with `ordering.equiv` against the elements collected so far,
    +   *                    keeping at most one null
    +   * @return a GenericArrayData holding the distinct elements
    +   * @throws RuntimeException if the result would exceed
    +   *                          `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` elements
    +   */
    +  def evalUnionContainsNull(
    +      array1: ArrayData,
    +      array2: ArrayData,
    +      elementType: DataType,
    +      ordering: Ordering[Any]): ArrayData = {
    +    val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +
    +    // Appends `elem`, failing *before* the buffer would grow past the size limit
    +    // (a `>` check here would let the buffer reach MAX_ROUNDED_ARRAY_LENGTH + 1).
    +    def append(elem: Any): Unit = {
    +      if (arrayBuffer.length >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +        throwUnionLengthOverflowException(arrayBuffer.length + 1)
    +      }
    +      arrayBuffer += elem
    +    }
    +
    +    if (ordering == null) {
    +      // Dedupe by element equality using a hash set; `add` returns false for duplicates.
    +      val seen = new mutable.HashSet[Any]
    +      Seq(array1, array2).foreach { array =>
    +        var i = 0
    +        while (i < array.numElements()) {
    +          val elem = array.get(i, elementType)
    +          if (seen.add(elem)) {
    +            append(elem)
    +          }
    +          i += 1
    +        }
    +      }
    +    } else {
    +      // Dedupe with `ordering.equiv` by scanning the elements collected so far. Nulls are
    +      // tracked with a flag and never passed to the ordering.
    +      var alreadyIncludeNull = false
    +      Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
    +        val found = if (elem == null) {
    +          val duplicateNull = alreadyIncludeNull
    +          alreadyIncludeNull = true
    +          duplicateNull
    +        } else {
    +          arrayBuffer.exists(va => va != null && ordering.equiv(va, elem))
    +        }
    +        if (!found) {
    +          append(elem)
    +        }
    +      }))
    +    }
    +    new GenericArrayData(arrayBuffer)
    +  }
    +}
    +
    +
    +/**
    + * Base class for binary set-like operations over two arrays, such as `array_union`.
    + *
    + * The element type must support ordering; this is enforced by `checkInputDataTypes`.
    + */
    +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
    +  /**
    +   * The left input's array type, with `containsNull` set whenever either input may contain
    +   * null.
    +   *
    +   * Returning `left.dataType` unchanged would declare `containsNull = false` for a result that
    +   * can carry nulls coming from the right input. Nullability is not unified by the implicit
    +   * cast, so it must be combined here. This is conservative (never wrongly false) for
    +   * subclasses.
    +   */
    +  override def dataType: DataType =
    +    left.dataType.asInstanceOf[ArrayType].copy(containsNull = cn)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val typeCheckResult = super.checkInputDataTypes()
    +    if (typeCheckResult.isSuccess) {
    +      // Set operations compare elements, so the element type must be orderable.
    +      TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
    +        s"function $prettyName")
    +    } else {
    +      typeCheckResult
    +    }
    +  }
    +
    +  /** True when either input array type declares `containsNull = true`. */
    +  protected def cn = left.dataType.asInstanceOf[ArrayType].containsNull ||
    +    right.dataType.asInstanceOf[ArrayType].containsNull
    +
    +  /** Interpreted ordering for `elementType`, used for element equality when `equals` is not. */
    +  @transient protected lazy val ordering: Ordering[Any] =
    +    TypeUtils.getInterpretedOrdering(elementType)
    +
    +  /**
    +   * True for atomic element types other than BinaryType.
    +   *
    +   * BinaryType is excluded, presumably because its values are byte arrays whose `equals`
    +   * is reference-based. TODO confirm.
    +   */
    +  @transient protected lazy val elementTypeSupportEquals = elementType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +}
    +
    +/**
    + * Returns an array of the elements in the union of x and y, without 
duplicates
    + */
    +@ExpressionDescription(
    +  usage = """
    +    _FUNC_(array1, array2) - Returns an array of the elements in the union 
of array1 and array2,
    +      without duplicates.
    +  """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
    +       array(1, 2, 3, 5)
    +  """,
    +  since = "2.4.0")
    +case class ArrayUnion(left: Expression, right: Expression) extends 
ArraySetLike {
    +
    +  override def nullSafeEval(input1: Any, input2: Any): Any = {
    +    val array1 = input1.asInstanceOf[ArrayData]
    +    val array2 = input2.asInstanceOf[ArrayData]
    +
    +    if (!cn) {
    +      elementType match {
    +        case IntegerType =>
    +          // avoid boxing of primitive int array elements
    +          // calculate result array size
    +          val hsSize = new OpenHashSet[Int]
    +          Seq(array1, array2).foreach(array => {
    +            var i = 0
    +            while (i < array.numElements()) {
    +              if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) 
{
    +                ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
    +              }
    +              hsSize.add(array.getInt(i))
    +              i += 1
    +            }
    +          })
    +          // store elements into array
    +          val resultArray = new Array[Int](hsSize.size)
    +          val hs = new OpenHashSet[Int]
    +          var pos = 0
    +          Seq(array1, array2).foreach(array => {
    +            var i = 0
    +            while (i < array.numElements () ) {
    +              val elem = array.getInt (i)
    +              if (!hs.contains (elem) ) {
    +                resultArray (pos) = elem
    +                hs.add (elem)
    +                pos += 1
    +              }
    +              i += 1
    +            }
    +          })
    +          if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize, 
resultArray.length)) {
    +            new GenericArrayData(resultArray)
    +          } else {
    +            UnsafeArrayData.fromPrimitiveArray(resultArray)
    +          }
    +        case LongType =>
    +          // avoid boxing of primitive long array elements
    +          // calculate result array size
    +          val hsSize = new OpenHashSet[Long]
    +          Seq(array1, array2).foreach(array => {
    +            var i = 0
    +            while (i < array.numElements()) {
    +              if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) 
{
    +                ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
    +              }
    +              hsSize.add(array.getLong(i))
    +              i += 1
    +            }
    +          })
    +          // store elements into array
    +          val resultArray = new Array[Long](hsSize.size)
    +          val hs = new OpenHashSet[Long]
    +          var pos = 0
    +          Seq(array1, array2).foreach(array => {
    +            var i = 0
    +            while (i < array.numElements()) {
    +              val elem = array.getLong(i)
    +              if (!hs.contains(elem)) {
    +                resultArray(pos) = elem
    +                hs.add(elem)
    +                pos += 1
    +              }
    +              i += 1
    +            }
    +          })
    +          if (ArraySetLike.useGenericArrayData(LongType.defaultSize, 
resultArray.length)) {
    +            new GenericArrayData(resultArray)
    +          } else {
    +            UnsafeArrayData.fromPrimitiveArray(resultArray)
    +          }
    +        case _ =>
    +          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +          val hs = new OpenHashSet[Any]
    +          Seq(array1, array2).foreach(array => {
    +            var i = 0
    +            while (i < array.numElements()) {
    +              val elem = array.get(i, elementType)
    +              if (!hs.contains(elem)) {
    +                if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +                  
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    +                }
    +                arrayBuffer += elem
    +                hs.add(elem)
    +              }
    +              i += 1
    +            }
    +          })
    +          new GenericArrayData(arrayBuffer)
    +      }
    +    } else {
    +      ArraySetLike.evalUnionContainsNull(array1, array2, elementType,
    +        if (elementTypeSupportEquals) null else ordering)
    +    }
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val i = ctx.freshName("i")
    +    val pos = ctx.freshName("pos")
    +    val value = ctx.freshName("value")
    +    val size = ctx.freshName("size")
    +    val genericArrayData = classOf[GenericArrayData].getName
    +    val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) = 
if (!cn) {
    +      val ptName = CodeGenerator.primitiveTypeName(elementType)
    +      elementType match {
    +        case ByteType | ShortType | IntegerType | LongType =>
    +          val unsafeArray = ctx.freshName("unsafeArray")
    +          (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
    +            s"scala.reflect.ClassTag$$.MODULE$$.$ptName()",
    +            s"get$ptName($i)", s"set$ptName($pos, $value)", 
CodeGenerator.javaType(elementType),
    +            s"""
    +              |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" 
$prettyName failed.")}
    +              |${ev.value} = $unsafeArray;
    +           """.stripMargin)
    +        case _ =>
    +          val et = ctx.addReferenceObj("elementType", elementType)
    +          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()",
    +            s"get($i, $et)", s"update($pos, $value)", "Object",
    +            s"${ev.value} = new $genericArrayData(new Object[$size]);")
    +      }
    +    } else {
    +      ("", "", "", "", "", "")
    +    }
    +
    +    val hs = ctx.freshName("hs")
    +    nullSafeCodeGen(ctx, ev, (array1, array2) => {
    +      if (classTag != "") {
    +        val openHashSet = classOf[OpenHashSet[_]].getName
    --- End diff --
    
    As you pointed out, this part works only when `containsNull = false`.
    We set `classTag` to a non-empty value only when `containsNull = false`. Since this
    is not obvious from the code, I will add a comment documenting this assumption.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to