Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21103#discussion_r206002948
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -3968,3 +3964,267 @@ object ArrayUnion {
         new GenericArrayData(arrayBuffer)
       }
     }
    +
    +/**
    + * Returns an array of the elements in the left array but not in the right array, without duplicates.
    + */
    +@ExpressionDescription(
    +  usage = """
    +  _FUNC_(array1, array2) - Returns an array of the elements in array1 but 
not in array2,
    +    without duplicates.
    +  """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
    +       array(2)
    +  """,
    +  since = "2.4.0")
    +case class ArrayExcept(left: Expression, right: Expression) extends 
ArraySetLike
    +    with ComplexTypeMergingExpression {
    +  /**
    +   * The result type is the left input's type. `dataTypeCheck` (defined elsewhere) is run
    +   * first; presumably it validates that both inputs have mergeable types -- TODO confirm.
    +   */
    +  override def dataType: DataType = { dataTypeCheck; left.dataType }
    +
    +  /**
    +   * Interpreted implementation of EXCEPT on two arrays.
    +   *
    +   * Returns every element of `array1` that does not occur in `array2`, in first-occurrence
    +   * order, with duplicates (including repeated nulls) collapsed to a single entry. A null in
    +   * `array1` is emitted at most once, and only if `array2` contains no null.
    +   *
    +   * When `elementTypeSupportEquals` holds, a hash-set based strategy is used; otherwise
    +   * elements are compared pairwise with `ordering.equiv` (nested scans, quadratic).
    +   */
    +  @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
    +    if (elementTypeSupportEquals) {
    +      (array1, array2) => exceptWithHashSet(array1, array2)
    +    } else {
    +      (array1, array2) => exceptWithOrdering(array1, array2)
    +    }
    +  }
    +
    +  /** EXCEPT using an `OpenHashSet` for membership; used when `elementTypeSupportEquals`. */
    +  private def exceptWithHashSet(array1: ArrayData, array2: ArrayData): ArrayData = {
    +    // Holds every non-null element of array2, then also every element already emitted,
    +    // so a single lookup rejects both excluded values and duplicates.
    +    val hs = new OpenHashSet[Any]
    +    // Stays true until a null is seen in array2 or a null has been emitted from array1.
    +    var notFoundNullElement = true
    +    var i = 0
    +    while (i < array2.numElements()) {
    +      if (array2.isNullAt(i)) {
    +        notFoundNullElement = false
    +      } else {
    +        hs.add(array2.get(i, elementType))
    +      }
    +      i += 1
    +    }
    +    val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +    i = 0
    +    while (i < array1.numElements()) {
    +      if (array1.isNullAt(i)) {
    +        if (notFoundNullElement) {
    +          arrayBuffer += null
    +          notFoundNullElement = false
    +        }
    +      } else {
    +        val elem = array1.get(i, elementType)
    +        if (!hs.contains(elem)) {
    +          arrayBuffer += elem
    +          hs.add(elem)
    +        }
    +      }
    +      i += 1
    +    }
    +    new GenericArrayData(arrayBuffer)
    +  }
    +
    +  /** EXCEPT comparing elements with `ordering.equiv`; used when hashing is not applicable. */
    +  private def exceptWithOrdering(array1: ArrayData, array2: ArrayData): ArrayData = {
    +    val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +    // array2 is scanned for a null only once; every later null in array1 is then skipped.
    +    var scannedNullElements = false
    +    var i = 0
    +    while (i < array1.numElements()) {
    +      var found = false
    +      if (array1.isNullAt(i)) {
    +        if (!scannedNullElements) {
    +          var j = 0
    +          while (!found && j < array2.numElements()) {
    +            found = array2.isNullAt(j)
    +            j += 1
    +          }
    +          scannedNullElements = true
    +        } else {
    +          found = true
    +        }
    +        if (!found) {
    +          arrayBuffer += null
    +        }
    +      } else {
    +        val elem1 = array1.get(i, elementType)
    +        var j = 0
    +        while (!found && j < array2.numElements()) {
    +          if (!array2.isNullAt(j)) {
    +            found = ordering.equiv(elem1, array2.get(j, elementType))
    +          }
    +          j += 1
    +        }
    +        // Not excluded by array2: still skip it if an equivalent value was already emitted.
    +        var k = 0
    +        while (!found && k < arrayBuffer.size) {
    +          val emitted = arrayBuffer(k)
    +          found = (emitted != null) && ordering.equiv(emitted, elem1)
    +          k += 1
    +        }
    +        if (!found) {
    +          arrayBuffer += elem1
    +        }
    +      }
    +      i += 1
    +    }
    +    new GenericArrayData(arrayBuffer)
    +  }
    +
    +  /** Casts both evaluated children to `ArrayData` and delegates to `evalExcept`. */
    +  override def nullSafeEval(input1: Any, input2: Any): Any =
    +    evalExcept(input1.asInstanceOf[ArrayData], input2.asInstanceOf[ArrayData])
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val arrayData = classOf[ArrayData].getName
    +    val i = ctx.freshName("i")
    +    val pos = ctx.freshName("pos")
    +    val value = ctx.freshName("value")
    +    val hsValue = ctx.freshName("hsValue")
    +    val size = ctx.freshName("size")
    +    val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
    +         getter, setter, javaTypeName, arrayBuilder) =
    +      if (elementTypeSupportEquals) {
    +        elementType match {
    +          case BooleanType | ByteType | ShortType | IntegerType =>
    +            val ptName = CodeGenerator.primitiveTypeName(elementType)
    +            val unsafeArray = ctx.freshName("unsafeArray")
    +            ("$mcI$sp", "Int", "int",
    +              if (elementType != BooleanType) {
    +                s"(int) $value"
    +              } else {
    +                s"$value ? 1 : 0;"
    +              },
    +              s"get$ptName($i)", s"set$ptName($pos, $value)", 
CodeGenerator.javaType(elementType),
    +              s"""
    +                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, 
s" $prettyName failed.")}
    +                 |${ev.value} = $unsafeArray;
    +               """.stripMargin)
    +          case LongType | FloatType | DoubleType =>
    +            val ptName = CodeGenerator.primitiveTypeName(elementType)
    +            val unsafeArray = ctx.freshName("unsafeArray")
    +            val signature = elementType match {
    +              case LongType => "$mcJ$sp"
    +              case FloatType => "$mcF$sp"
    +              case DoubleType => "$mcD$sp"
    +            }
    +            (signature, CodeGenerator.boxedType(elementType),
    +              CodeGenerator.javaType(elementType), value,
    +              s"get$ptName($i)", s"set$ptName($pos, $value)", 
CodeGenerator.javaType(elementType),
    +              s"""
    +                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, 
s" $prettyName failed.")}
    +                 |${ev.value} = $unsafeArray;
    +               """.stripMargin)
    +          case _ =>
    +            val genericArrayData = classOf[GenericArrayData].getName
    +            val et = ctx.addReferenceObj("elementType", elementType)
    +            ("", "Object", "Object", value,
    +              s"get($i, $et)", s"update($pos, $value)", "Object",
    +              s"${ev.value} = new $genericArrayData(new Object[$size]);")
    +        }
    +      } else {
    +        ("", "", "", "", "", "", "", "")
    +      }
    +
    +    nullSafeCodeGen(ctx, ev, (array1, array2) => {
    +      if (openHashElementType != "") {
    +        // Here, we ensure elementTypeSupportEquals is true
    +        val notFoundNullElement = ctx.freshName("notFoundNullElement")
    +        val openHashSet = classOf[OpenHashSet[_]].getName
    +        val classTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
    +        val hs = ctx.freshName("hs")
    +        val arrayData = classOf[ArrayData].getName
    +        val arrays = ctx.freshName("arrays")
    +        val array = ctx.freshName("array")
    +        val arrayDataIdx = ctx.freshName("arrayDataIdx")
    +
    +        val array2NullCheck = if 
(right.dataType.asInstanceOf[ArrayType].containsNull) {
    +          s"""
    +            |if ($array2.isNullAt($i)) {
    +            |  $notFoundNullElement = false;
    +            |} else
    +           """.stripMargin
    +        } else {
    +          ""
    +        }
    +        val array1NullCheck = if 
(left.dataType.asInstanceOf[ArrayType].containsNull) {
    +          s"""
    +             |if ($array1.isNullAt($i)) {
    +             |  if ($notFoundNullElement) {
    +             |    $size++;
    +             |    $notFoundNullElement = false;
    +             |  }
    +             |} else
    +           """.stripMargin
    +        } else {
    +          ""
    +        }
    +        val array1NullAssignment = if 
(left.dataType.asInstanceOf[ArrayType].containsNull) {
    +          s"""
    +             |if ($array1.isNullAt($i)) {
    +             |  if ($notFoundNullElement) {
    +             |    ${ev.value}.setNullAt($pos++);
    +             |    $notFoundNullElement = false;
    +             |  }
    +             |} else
    +           """.stripMargin
    +        } else {
    +          ""
    +        }
    +
    +        s"""
    +           |$openHashSet $hs = new $openHashSet$postFix($classTag);
    +           |boolean $notFoundNullElement = true;
    +           |int $size = 0;
    +           |for (int $i = 0; $i < $array2.numElements(); $i++) {
    +           |  $array2NullCheck
    +           |  {
    +           |    $javaTypeName $value = $array2.$getter;
    +           |    $hsJavaTypeName $hsValue = $genHsValue;
    +           |    $hs.add$postFix($hsValue);
    +           |  }
    +           |}
    +           |for (int $i = 0; $i < $array1.numElements(); $i++) {
    +           |  $array1NullCheck
    +           |  {
    +           |    $javaTypeName $value = $array1.$getter;
    +           |    $hsJavaTypeName $hsValue = $genHsValue;
    +           |    if (!$hs.contains($hsValue)) {
    +           |      $hs.add$postFix($hsValue);
    +           |      $size++;
    +           |    }
    +           |  }
    +           |}
    +           |$arrayBuilder
    +           |$hs = new $openHashSet$postFix($classTag);
    +           |$notFoundNullElement = true;
    +           |int $pos = 0;
    +           |for (int $i = 0; $i < $array2.numElements(); $i++) {
    --- End diff --
    
    Yes, it only calculates the array size needed to allocate an `UnsafeArrayData` or `GenericArrayData`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to