This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new b00d0b460d6 [SPARK-41234][SQL][PYTHON] Add `array_insert` function
b00d0b460d6 is described below

commit b00d0b460d6ab05951a3620632b80fcb06907970
Author: Daniel Davies <ddav...@palantir.com>
AuthorDate: Mon Feb 6 16:00:11 2023 +0800

    [SPARK-41234][SQL][PYTHON] Add `array_insert` function
    
    ### What changes were proposed in this pull request?
    
    This PR implements the `array_insert` function, similar to the Snowflake implementation [here](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_insert.html). The main difference is that Spark arrays are 1-indexed while Snowflake arrays are 0-indexed, so the definition of array_insert differs by one index for positive positions.
    
    #### Arguments
    arr: array column (of any element type) into which the element is inserted.
    pos: numeric column indicating the position of insertion (starting at index 1).
    item: value to insert into arr. Its type has to match the element type of the array.
    
    `select array_insert(array(1, 2, 3), 3, 4);`
    | array_insert(array(1, 2, 3), 3, 4) |
    |------------------------------------|
    | [1, 2, 4, 3]                       |
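    
    For illustration, a minimal PySpark sketch of the same semantics (assuming a running `SparkSession` bound to `spark`; the expected results mirror the golden-file outputs added in this patch):
    
    ```python
    from pyspark.sql import functions as F
    
    df = spark.range(1).select(F.array(F.lit(1), F.lit(3), F.lit(4)).alias("arr"))
    
    # A negative position counts from the end: inserting 2 at pos -2 gives [1, 2, 3, 4].
    df.select(F.array_insert("arr", F.lit(-2).cast("integer"), F.lit(2))).show()
    
    # A position past the end pads the gap with nulls: pos 6 gives [1, 3, 4, null, null, 9].
    df.select(F.array_insert("arr", F.lit(6).cast("integer"), F.lit(9))).show()
    ```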
    
    #### Support in other frameworks
    [MySQL](https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-array-insert)
    
    [Snowflake](https://docs.snowflake.com/en/sql-reference/functions/array_insert.html)
    
    [SQL Notebook](https://sqlnotebook.com/array-insert-func.html)
    
    ### Why are the changes needed?
    Implementation of an existing Snowflake function that is not yet in Spark. Use cases include inserting arbitrary elements into an array at a given index.
    
    ### Does this PR introduce _any_ user-facing change?
    No changes to existing APIs; it adds a new `array_insert` function that can be accessed via Scala Spark and PySpark.
    
    ### How was this patch tested?
    Tests added in CollectionExpressionsSuite (the Scala implementation suite), DataFrameFunctionsSuite, and the SQL query tests (array.sql).
    
    Closes #38867 from Daniel-Davies/ddavies/SPARK-41234.
    
    Lead-authored-by: Daniel Davies <ddav...@palantir.com>
    Co-authored-by: Daniel-Davies <33356828+daniel-dav...@users.noreply.github.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit fbfcd81e65630260273148fb19ee3e471056119d)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/pyspark/sql/functions.py                    |  37 +++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/collectionOperations.scala         | 251 +++++++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   |  72 ++++++
 .../scala/org/apache/spark/sql/functions.scala     |  10 +
 .../sql-functions/sql-expression-schema.md         |   1 +
 .../src/test/resources/sql-tests/inputs/array.sql  |  12 +
 .../resources/sql-tests/results/ansi/array.sql.out |  98 ++++++++
 .../test/resources/sql-tests/results/array.sql.out |  98 ++++++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  44 ++++
 11 files changed, 625 insertions(+)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index ddc8eab90f7..70fc04ef9cf 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -157,6 +157,7 @@ Collection Functions
     element_at
     array_append
     array_sort
+    array_insert
     array_remove
     array_distinct
     array_intersect
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3426f2bdaf6..157432e07b0 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7679,6 +7679,43 @@ def array_distinct(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("array_distinct", col)
 
 
+@try_remote_functions
+def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName") -> Column:
+    """
+    Collection function: adds an item into a given array at a specified array index.
+    Array indices start at 1, or from the end if the index is negative.
+    If the index exceeds the array length (plus the inserted element), the array
+    is padded with 'null' elements up to that index.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    arr : :class:`~pyspark.sql.Column` or str
+        name of column containing an array
+    pos : :class:`~pyspark.sql.Column` or str
+        name of Numeric type column indicating position of insertion
+        (starting at index 1; a negative position counts from the back of the array)
+    value : :class:`~pyspark.sql.Column` or str
+        name of column containing values for insertion into array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of values, including the new specified value
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame(
+    ...     [(['a', 'b', 'c'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')],
+    ...     ['data', 'pos', 'val']
+    ... )
+    >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
+    [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
+    """
+    return _invoke_function_over_columns("array_insert", arr, pos, value)
+
+
 @try_remote_functions
 def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 8d28826491f..6181cc7cae9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -661,6 +661,7 @@ object FunctionRegistry {
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
     expression[ArraysOverlap]("arrays_overlap"),
+    expression[ArrayInsert]("array_insert"),
     expression[ArrayIntersect]("array_intersect"),
     expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ca3982f54c8..92a3127d438 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4601,6 +4601,257 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
     newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = newLeft, right = newRight)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(x, pos, val) - Places val into index pos of array x (array 
indices start at 1, or start from the end if start is negative).\",",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3, 4), 5, 5);
+       [1,2,3,4,5]
+      > SELECT _FUNC_(array(5, 3, 2, 1), -3, 4);
+       [5,4,3,2,1]
+  """,
+  group = "array_funcs",
+  since = "3.4.0")
+case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
+    with QueryErrorsBase {
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
+      case (ArrayType(e1, hasNull), e2: IntegralType, e3) if (e2 != LongType) =>
+        TypeCoercion.findTightestCommonType(e1, e3) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), IntegerType, dt)
+          case _ => Seq.empty
+        }
+      case (e1, e2, e3) => Seq.empty
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (first.dataType, second.dataType, third.dataType) match {
+      case (_: ArrayType, e2, e3) if e2 != IntegerType =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "2",
+            "requiredType" -> toSQLType(IntegerType),
+            "inputSql" -> toSQLExpr(second),
+            "inputType" -> toSQLType(second.dataType))
+        )
+      case (ArrayType(e1, _), e2, e3) if e1.sameType(e3) =>
+        TypeCheckResult.TypeCheckSuccess
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "dataType" -> toSQLType(ArrayType),
+            "leftType" -> toSQLType(first.dataType),
+            "rightType" -> toSQLType(third.dataType)
+          )
+        )
+    }
+  }
+
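+  // `eval` is overridden here (instead of relying on the default null-intolerant
+  // TernaryExpression.eval) so that a null item can still be inserted into the array.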
+  override def eval(input: InternalRow): Any = {
+    val value1 = first.eval(input)
+    if (value1 != null) {
+      val value2 = second.eval(input)
+      if (value2 != null) {
+        val value3 = third.eval(input)
+        return nullSafeEval(value1, value2, value3)
+      }
+    }
+    null
+  }
+
+  override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
+    val baseArr = arr.asInstanceOf[ArrayData]
+    var posInt = pos.asInstanceOf[Int]
+    val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
+
+    val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())
+
+    if (newPosExtendsArrayLeft) {
+      // special case - if the new position is negative but larger in magnitude than
+      // the current array size, place the new item at the start of the array, place
+      // the current array contents at the end, and fill the newly created array
+      // elements in between with nulls
+
+      val newArrayLength = -posInt + 1
+
+      if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+      }
+
+      val newArray = new Array[Any](newArrayLength)
+
+      baseArr.foreach(arrayElementType, (i, v) => {
+        // current position, offset by new item + new null array elements
+        val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
+        newArray(elementPosition) = v
+      })
+
+      newArray(0) = item
+
+      return new GenericArrayData(newArray)
+    } else {
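+      // convert the 1-based position (negative values count from the end)
+      // into a 0-based insertion index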
+      if (posInt < 0) {
+        posInt = posInt + baseArr.numElements()
+      } else if (posInt > 0) {
+        posInt = posInt - 1
+      }
+
+      val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
+
+      if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+      }
+
+      val newArray = new Array[Any](newArrayLength)
+
+      baseArr.foreach(arrayElementType, (i, v) => {
+        if (i >= posInt) {
+          newArray(i + 1) = v
+        } else {
+          newArray(i) = v
+        }
+      })
+
+      newArray(posInt) = item
+
+      return new GenericArrayData(newArray)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
+      val arr = arrExpr.value
+      val pos = posExpr.value
+      val item = itemExpr.value
+
+      val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
+      val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
+      val resLength = ctx.freshName("resLength")
+      val insertedItemIsNull = ctx.freshName("insertedItemIsNull")
+      val i = ctx.freshName("i")
+      val j = ctx.freshName("j")
+      val values = ctx.freshName("values")
+
+      val allocation = CodeGenerator.createArrayData(
+        values, elementType, resLength, s"$prettyName failed.")
+      val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr,
+        adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull)
+
+      s"""
+         |int $itemInsertionIndex = 0;
+         |int $resLength = 0;
+         |int $adjustedAllocIdx = 0;
+         |boolean $insertedItemIsNull = ${itemExpr.isNull};
+         |
+         |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
+         |
+         |  $resLength = java.lang.Math.abs($pos) + 1;
+         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+         |    throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+         |  }
+         |
+         |  $allocation
+         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |    $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
+         |    $assignment
+         |  }
+         |  ${CodeGenerator.setArrayElement(
+              values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
+         |
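+         |  // fill the gap between the inserted head and the shifted contents with nulls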
+         |  for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
+         |    $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
+         |  }
+         |
+         |  ${ev.value} = $values;
+         |} else {
+         |
+         |  $itemInsertionIndex = 0;
+         |  if ($pos < 0) {
+         |    $itemInsertionIndex = $pos + $arr.numElements();
+         |  } else if ($pos > 0) {
+         |    $itemInsertionIndex = $pos - 1;
+         |  }
+         |
+         |  $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1);
+         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+         |    throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+         |  }
+         |
+         |  $allocation
+         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |    $adjustedAllocIdx = $i;
+         |    if ($i >= $itemInsertionIndex) {
+         |      $adjustedAllocIdx = $adjustedAllocIdx + 1;
+         |    }
+         |    $assignment
+         |  }
+         |  ${CodeGenerator.setArrayElement(
+              values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
+         |
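+         |  // a position past the end leaves a gap before the inserted item; fill it with nulls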
+         |  for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
+         |    $values.setNullAt($j);
+         |  }
+         |
+         |  ${ev.value} = $values;
+         |}
+      """.stripMargin
+    }
+
+    val leftGen = first.genCode(ctx)
+    val midGen = second.genCode(ctx)
+    val rightGen = third.genCode(ctx)
+    val resultCode = f(leftGen, midGen, rightGen)
+
+    if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(first.nullable, leftGen.isNull) {
+          midGen.code + ctx.nullSafeExec(second.nullable, midGen.isNull) {
+            s"""
+              ${rightGen.code}
+              ${ev.isNull} = false;
+              $resultCode
+            """
+          }
+        }
+
+      ev.copy(code = code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $nullSafeEval""")
+    } else {
+      ev.copy(code = code"""
+        ${leftGen.code}
+        ${midGen.code}
+        ${rightGen.code}
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $resultCode""", isNull = FalseLiteral)
+    }
+  }
+
+  override def first: Expression = srcArrayExpr
+  override def second: Expression = posExpr
+  override def third: Expression = itemExpr
+
+  override def prettyName: String = "array_insert"
+  override def dataType: DataType = first.dataType
+  override def nullable: Boolean = first.nullable || second.nullable
+
+  @transient private lazy val elementType: DataType =
+    srcArrayExpr.dataType.asInstanceOf[ArrayType].elementType
+
+  override protected def withNewChildrenInternal(
+      newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert =
+    copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr)
+}
+
 @ExpressionDescription(
   usage = "_FUNC_(array) - Removes null values from the array.",
   examples = """
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d83739df38d..9b97430594d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2250,6 +2250,78 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(2d))
   }
 
+  test("Array Insert") {
+    val a1 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType))
+    val a3 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType))
+    val a4 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType))
+    val a5 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType))
+    val a6 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType))
+    val a7 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType))
+    val a8 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType))
+    val a9 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
+    val a10 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true))
+    val a11 = Literal.create(null, ArrayType(StringType))
+
+    // basic additions per type
+    checkEvaluation(ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
+    checkEvaluation(
+      ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
+      Seq[Boolean](true, false, true, true)
+    )
+    checkEvaluation(
+      ArrayInsert(
+        a4,
+        Literal(3),
+        Literal.create(5.asInstanceOf[Byte], ByteType)),
+      Seq[Byte](1, 2, 5, 3, 2))
+
+    checkEvaluation(
+      ArrayInsert(
+        a5,
+        Literal(3),
+        Literal.create(3.asInstanceOf[Short], ShortType)),
+      Seq[Short](1, 2, 3, 3, 2))
+
+    checkEvaluation(
+      ArrayInsert(a7, Literal(4), Literal(4.4)),
+      Seq[Double](1.1, 2.2, 3.3, 4.4, 2.2)
+    )
+
+    checkEvaluation(
+      ArrayInsert(a6, Literal(4), Literal(4.4F)),
+      Seq(1.1F, 2.2F, 3.3F, 4.4F, 2.2F)
+    )
+    checkEvaluation(ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
+    checkEvaluation(ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))
+
+    // index edge cases
+    checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(0), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
+    checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(-3), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, null, 1, 2, 4))
+    checkEvaluation(
+      ArrayInsert(a1, Literal(10), Literal(3)),
+      Seq(1, 2, 4, null, null, null, null, null, null, 3)
+    )
+    checkEvaluation(
+      ArrayInsert(a1, Literal(-10), Literal(3)),
+      Seq(3, null, null, null, null, null, null, null, 1, 2, 4)
+    )
+
+    // null handling
+    checkEvaluation(ArrayInsert(
+      a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
+    )
+    checkEvaluation(ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
+    checkEvaluation(ArrayInsert(a10, Literal(3), Literal("d")), Seq("b", null, "d", "a", "g", null))
+    checkEvaluation(ArrayInsert(a11, Literal(3), Literal("d")), null)
+    checkEvaluation(ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
+  }
+
   test("Array Intersect") {
     val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false))
     val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3d5547ead83..cb5c1ad5c49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4062,6 +4062,16 @@ object functions {
     ArrayIntersect(col1.expr, col2.expr)
   }
 
+  /**
+   * Adds an item into a given array at a specified position.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def array_insert(arr: Column, pos: Column, value: Column): Column = withExpr {
+    ArrayInsert(arr.expr, pos.expr, value.expr)
+  }
+
   /**
    * Returns an array of the elements in the union of the given two arrays, without duplicates.
    *
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 9b8d50d2ede..ef5c4addc84 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -20,6 +20,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayExists | exists | SELECT exists(array(1, 2, 3), x -> x % 2 == 0) | struct<exists(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 0), namedlambdavariable())):boolean> |
 | org.apache.spark.sql.catalyst.expressions.ArrayFilter | filter | SELECT filter(array(1, 2, 3), x -> x % 2 == 1) | struct<filter(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 1), namedlambdavariable())):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayForAll | forall | SELECT forall(array(1, 2, 3), x -> x % 2 == 0) | struct<forall(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 0), namedlambdavariable())):boolean> |
+| org.apache.spark.sql.catalyst.expressions.ArrayInsert | array_insert | SELECT array_insert(array(1, 2, 3, 4), 5, 5) | struct<array_insert(array(1, 2, 3, 4), 5, 5):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayIntersect | array_intersect | SELECT array_intersect(array(1, 2, 3), array(1, 3, 5)) | struct<array_intersect(array(1, 2, 3), array(1, 3, 5)):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayJoin | array_join | SELECT array_join(array('hello', 'world'), ' ') | struct<array_join(array(hello, world),  ):string> |
 | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index f2f0dac98e2..3d107cb6dfc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -130,6 +130,18 @@ select get(array(1, 2, 3), 3);
 select get(array(1, 2, 3), null);
 select get(array(1, 2, 3), -1);
 
+-- function array_insert()
+select array_insert(array(1, 2, 3), 3, 4);
+select array_insert(array(2, 3, 4), 0, 1);
+select array_insert(array(2, 3, 4), 1, 1);
+select array_insert(array(1, 3, 4), -2, 2);
+select array_insert(array(1, 2, 3), 3, "4");
+select array_insert(cast(NULL as ARRAY<INT>), 1, 1);
+select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4);
+select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT));
+select array_insert(array(2, 3, NULL, 4), 5, 5);
+select array_insert(array(2, 3, NULL, 4), -5, 1);
+
 -- function array_compact
 select array_compact(id) from values (1) as t(id);
 select array_compact(array("1", null, "2", null));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index 714775d7e8b..0d8ef39ed60 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -550,6 +550,104 @@ struct<get(array(1, 2, 3), -1):int>
 NULL
 
 
+-- !query
+select array_insert(array(1, 2, 3), 3, 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
+-- !query output
+[1,2,4,3]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 0, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 1, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 1, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 3, 4), -2, 2)
+-- !query schema
+struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 2, 3), 3, "4")
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "dataType" : "\"ARRAY\"",
+    "functionName" : "`array_insert`",
+    "leftType" : "\"ARRAY<INT>\"",
+    "rightType" : "\"STRING\"",
+    "sqlExpr" : "\"array_insert(array(1, 2, 3), 3, 4)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 43,
+    "fragment" : "array_insert(array(1, 2, 3), 3, \"4\")"
+  } ]
+}
+
+
+-- !query
+select array_insert(cast(NULL as ARRAY<INT>), 1, 1)
+-- !query schema
+struct<array_insert(NULL, 1, 1):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT))
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT)):array<int>>
+-- !query output
+[1,2,3,null,null]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), 5, 5)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), 5, 5):array<int>>
+-- !query output
+[2,3,null,4,5]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), -5, 1)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>>
+-- !query output
+[1,null,2,3,null,4]
+
+
 -- !query
 select array_compact(id) from values (1) as t(id)
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 94d4c7987a8..609122a23d3 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -431,6 +431,104 @@ struct<get(array(1, 2, 3), -1):int>
 NULL
 
 
+-- !query
+select array_insert(array(1, 2, 3), 3, 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
+-- !query output
+[1,2,4,3]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 0, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 1, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 1, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 3, 4), -2, 2)
+-- !query schema
+struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 2, 3), 3, "4")
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "dataType" : "\"ARRAY\"",
+    "functionName" : "`array_insert`",
+    "leftType" : "\"ARRAY<INT>\"",
+    "rightType" : "\"STRING\"",
+    "sqlExpr" : "\"array_insert(array(1, 2, 3), 3, 4)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 43,
+    "fragment" : "array_insert(array(1, 2, 3), 3, \"4\")"
+  } ]
+}
+
+
+-- !query
+select array_insert(cast(NULL as ARRAY<INT>), 1, 1)
+-- !query schema
+struct<array_insert(NULL, 1, 1):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT))
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT)):array<int>>
+-- !query output
+[1,2,3,null,null]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), 5, 5)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), 5, 5):array<int>>
+-- !query output
+[2,3,null,4,5]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), -5, 1)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>>
+-- !query output
+[1,null,2,3,null,4]
+
+
 -- !query
 select array_compact(id) from values (1) as t(id)
 -- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b7daa45a42a..14def67ba40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3106,6 +3106,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("array_insert functions") {
+    val fiveShort: Short = 5
+
+    val df1 = Seq((Array[Integer](3, 2, 5, 1, 2), 6, 3)).toDF("a", "b", "c")
+    val df2 = Seq((Array[Short](1, 2, 3, 4), 5, fiveShort)).toDF("a", "b", "c")
+    val df3 = Seq((Array[Double](3.0, 2.0, 5.0, 1.0, 2.0), 2, 3.0)).toDF("a", "b", "c")
+    val df4 = Seq((Array[Boolean](true, false), 3, false)).toDF("a", "b", "c")
+    val df5 = Seq((Array[String]("a", "b", "c"), 0, "d")).toDF("a", "b", "c")
+    val df6 = Seq((Array[String]("a", null, "b", "c"), 5, "d")).toDF("a", "b", 
"c")
+
+    checkAnswer(df1.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3))))
+    checkAnswer(df2.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq[Short](1, 2, 3, 4, 5))))
+    checkAnswer(
+      df3.selectExpr("array_insert(a, b, c)"),
+      Seq(Row(Seq[Double](3.0, 3.0, 2.0, 5.0, 1.0, 2.0)))
+    )
+    checkAnswer(df4.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(true, false, false))))
+    checkAnswer(df5.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("d", "a", "b", "c"))))
+    checkAnswer(df5.select(
+      array_insert(col("a"), lit(1), col("c"))),
+      Seq(Row(Seq("d", "a", "b", "c")))
+    )
+    // null checks
+    checkAnswer(df6.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("a", null, "b", "c", "d"))))
+    checkAnswer(df5.select(
+      array_insert(col("a"), col("b"), lit(null).cast("string"))),
+      Seq(Row(Seq(null, "a", "b", "c")))
+    )
+    checkAnswer(df6.select(
+      array_insert(col("a"), col("b"), lit(null).cast("string"))),
+      Seq(Row(Seq("a", null, "b", "c", null)))
+    )
+    checkAnswer(
+      df5.select(array_insert(col("a"), lit(null).cast("integer"), col("c"))),
+      Seq(Row(null))
+    )
+    checkAnswer(
+      df5.select(array_insert(lit(null).cast("array<string>"), col("b"), col("c"))),
+      Seq(Row(null))
+    )
+    checkAnswer(df1.selectExpr("array_insert(a, 7, c)"), Seq(Row(Seq(3, 2, 5, 
1, 2, null, 3))))
+    checkAnswer(df1.selectExpr("array_insert(a, -6, c)"), Seq(Row(Seq(3, null, 
3, 2, 5, 1, 2))))
+  }
+
   test("transform function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
