This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 50f20fbad36 [SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend`
50f20fbad36 is described below

commit 50f20fbad36dbb05a5132a3043364eff7b1a565c
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Fri Aug 25 22:51:08 2023 -0700

    [SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend`
    
    ### What changes were proposed in this pull request?
    In this PR, I propose to replace the current implementation of the
    `ArrayAppend` expression with a runtime-replaceable one that delegates to
    `ArrayInsert` with `posExpr = -1`.
    
    ### Why are the changes needed?
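    To improve code maintainability: `ArrayAppend` no longer duplicates the evaluation and code-generation logic that `ArrayInsert` already implements.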
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running the modified test suite:
    ```
    $ build/sbt "test:testOnly *CollectionExpressionsSuite"
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42660 from MaxGekk/array_append-to-insert-1.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../explain-results/function_array_append.explain  |   2 +-
 .../expressions/collectionOperations.scala         | 229 ++++++---------------
 .../expressions/CollectionExpressionsSuite.scala   | 144 ++++++-------
 3 files changed, 119 insertions(+), 256 deletions(-)

diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
index ca2804ebb60..e857e2e974f 100644
--- 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
@@ -1,2 +1,2 @@
-Project [array_append(e#0, 1) AS array_append(e, 1)#0]
+Project [array_insert(e#0, -1, 1, false) AS array_append(e, 1)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fe9c4015c15..957aa1ab2d5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1397,29 +1397,9 @@ case class ArrayContains(left: Expression, right: 
Expression)
     copy(left = newLeft, right = newRight)
 }
 
-@ExpressionDescription(
-  usage = """
-      _FUNC_(array, element) - Add the element at the beginning of the array 
passed as first
-      argument. Type of element should be the same as the type of the elements 
of the array.
-      Null element is also prepended to the array. But if the array passed is 
NULL
-      output is NULL
-    """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
-       ["d","b","d","c","a"]
-      > SELECT _FUNC_(array(1, 2, 3, null), null);
-       [null,1,2,3,null]
-      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
-       NULL
-  """,
-  group = "array_funcs",
-  since = "3.5.0")
-case class ArrayPrepend(left: Expression, right: Expression) extends 
RuntimeReplaceable
+trait ArrayPendBase extends RuntimeReplaceable
   with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase 
{
 
-  override lazy val replacement: Expression = new ArrayInsert(left, 
Literal(1), right)
-
   override def inputTypes: Seq[AbstractDataType] = {
     (left.dataType, right.dataType) match {
       case (ArrayType(e1, hasNull), e2) =>
@@ -1455,6 +1435,29 @@ case class ArrayPrepend(left: Expression, right: 
Expression) extends RuntimeRepl
         )
     }
   }
+}
+
+@ExpressionDescription(
+  usage = """
+      _FUNC_(array, element) - Add the element at the beginning of the array 
passed as first
+      argument. Type of element should be the same as the type of the elements 
of the array.
+      Null element is also prepended to the array. But if the array passed is 
NULL
+      output is NULL
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["d","b","d","c","a"]
+      > SELECT _FUNC_(array(1, 2, 3, null), null);
+       [null,1,2,3,null]
+      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+       NULL
+  """,
+  group = "array_funcs",
+  since = "3.5.0")
+case class ArrayPrepend(left: Expression, right: Expression) extends 
ArrayPendBase {
+
+  override lazy val replacement: Expression = new ArrayInsert(left, 
Literal(1), right)
 
   override def prettyName: String = "array_prepend"
 
@@ -1463,6 +1466,41 @@ case class ArrayPrepend(left: Expression, right: 
Expression) extends RuntimeRepl
     copy(left = newLeft, right = newRight)
 }
 
+
+/**
+ * Given an array, and another element append the element at the end of the 
array.
+ * This function does not return null when the elements are null. It appends 
null at
+ * the end of the array. But returns null if the array passed is null.
+ */
+@ExpressionDescription(
+  usage = """
+      _FUNC_(array, element) - Add the element at the end of the array passed 
as first
+      argument. Type of element should be similar to type of the elements of 
the array.
+      Null element is also appended into the array. But if the array passed, 
is NULL
+      output is NULL
+      """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["b","d","c","a","d"]
+      > SELECT _FUNC_(array(1, 2, 3, null), null);
+       [1,2,3,null,null]
+      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+       NULL
+  """,
+  since = "3.4.0",
+  group = "array_funcs")
+case class ArrayAppend(left: Expression, right: Expression) extends 
ArrayPendBase {
+
+  override lazy val replacement: Expression = new ArrayInsert(left, 
Literal(-1), right)
+
+  override def prettyName: String = "array_append"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayAppend =
+    copy(left = newLeft, right = newRight)
+}
+
 /**
  * Checks if the two arrays contain at least one common element.
  */
@@ -5039,152 +5077,3 @@ case class ArrayCompact(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): 
ArrayCompact =
     copy(child = newChild)
 }
-
-/**
- * Given an array, and another element append the element at the end of the 
array.
- * This function does not return null when the elements are null. It appends 
null at
- * the end of the array. But returns null if the array passed is null.
- */
-@ExpressionDescription(
-  usage = """
-      _FUNC_(array, element) - Add the element at the end of the array passed 
as first
-      argument. Type of element should be similar to type of the elements of 
the array.
-      Null element is also appended into the array. But if the array passed, 
is NULL
-      output is NULL
-      """,
-  examples = """
-    Examples:
-      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
-       ["b","d","c","a","d"]
-      > SELECT _FUNC_(array(1, 2, 3, null), null);
-       [1,2,3,null,null]
-      > SELECT _FUNC_(CAST(null as Array<Int>), 2);
-       NULL
-  """,
-  since = "3.4.0",
-  group = "array_funcs")
-case class ArrayAppend(left: Expression, right: Expression)
-  extends BinaryExpression
-    with ImplicitCastInputTypes
-    with ComplexTypeMergingExpression
-    with QueryErrorsBase {
-  override def prettyName: String = "array_append"
-
-  @transient protected lazy val elementType: DataType =
-    inputTypes.head.asInstanceOf[ArrayType].elementType
-
-  override def inputTypes: Seq[AbstractDataType] = {
-    (left.dataType, right.dataType) match {
-      case (ArrayType(e1, hasNull), e2) =>
-        TypeCoercion.findTightestCommonType(e1, e2) match {
-          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
-          case _ => Seq.empty
-        }
-      case _ => Seq.empty
-    }
-  }
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    (left.dataType, right.dataType) match {
-      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => 
TypeCheckResult.TypeCheckSuccess
-      case (ArrayType(e1, _), e2) => DataTypeMismatch(
-        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
-        messageParameters = Map(
-          "functionName" -> toSQLId(prettyName),
-          "leftType" -> toSQLType(left.dataType),
-          "rightType" -> toSQLType(right.dataType),
-          "dataType" -> toSQLType(ArrayType)
-        ))
-      case _ =>
-        DataTypeMismatch(
-          errorSubClass = "UNEXPECTED_INPUT_TYPE",
-          messageParameters = Map(
-            "paramIndex" -> "0",
-            "requiredType" -> toSQLType(ArrayType),
-            "inputSql" -> toSQLExpr(left),
-            "inputType" -> toSQLType(left.dataType)
-          )
-        )
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    val value1 = left.eval(input)
-    if (value1 == null) {
-      null
-    } else {
-      val value2 = right.eval(input)
-      nullSafeEval(value1, value2)
-    }
-  }
-
-  override protected def nullSafeEval(arr: Any, elementData: Any): Any = {
-    val arrayData = arr.asInstanceOf[ArrayData]
-    val numberOfElements = arrayData.numElements() + 1
-    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-      throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
-    }
-    val finalData = new Array[Any](numberOfElements)
-    arrayData.foreach(elementType, finalData.update)
-    finalData.update(arrayData.numElements(), elementData)
-    new GenericArrayData(finalData)
-  }
-
-  override def nullable: Boolean = left.nullable
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val leftGen = left.genCode(ctx)
-    val rightGen = right.genCode(ctx)
-    val f = (eval1: String, eval2: String) => {
-      val newArraySize = ctx.freshName("newArraySize")
-      val i = ctx.freshName("i")
-      val values = ctx.freshName("values")
-      val allocation = CodeGenerator.createArrayData(
-        values, elementType, newArraySize, s" $prettyName failed.")
-      val assignment = CodeGenerator.createArrayAssignment(
-        values, elementType, eval1, i, i, 
left.dataType.asInstanceOf[ArrayType].containsNull)
-      s"""
-         |int $newArraySize = $eval1.numElements() + 1;
-         |$allocation
-         |int $i = 0;
-         |while ($i < $eval1.numElements()) {
-         |  $assignment
-         |  $i ++;
-         |}
-         |${CodeGenerator.setArrayElement(values, elementType, i, eval2, 
Some(rightGen.isNull))}
-         |${ev.value} = $values;
-         |""".stripMargin
-    }
-    val resultCode = f(leftGen.value, rightGen.value)
-    if (nullable) {
-      val nullSafeEval =
-        leftGen.code + rightGen.code + ctx.nullSafeExec(left.nullable, 
leftGen.isNull) {
-          s"""
-            ${ev.isNull} = false; // resultCode could change nullability.
-            $resultCode
-          """
-        }
-      ev.copy(code =
-        code"""
-        boolean ${ev.isNull} = true;
-        ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
-        $nullSafeEval
-      """)
-    } else {
-      ev.copy(code =
-        code"""
-        ${leftGen.code}
-        ${rightGen.code}
-        ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
-        $resultCode""", isNull = FalseLiteral)
-    }
-  }
-
-  /**
-   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
-   * the dataType of an unresolved expression (i.e., when `resolved` == false).
-   */
-  override def dataType: DataType = if (right.nullable) 
left.dataType.asNullable else left.dataType
-  protected def withNewChildrenInternal(newLeft: Expression, newRight: 
Expression): ArrayAppend =
-    copy(left = newLeft, right = newRight)
-
-}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1787f6ac72d..ff393857c31 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2328,6 +2328,27 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     )
 
     // null handling
+    checkEvaluation(
+      ArrayInsert(
+        Literal.create(null, ArrayType(StringType)),
+        Literal(-1),
+        Literal.create("c", StringType),
+        legacyNegativeIndex = false),
+      null)
+    checkEvaluation(
+      ArrayInsert(
+        Literal.create(null, ArrayType(StringType)),
+        Literal(-1),
+        Literal.create(null, StringType),
+        legacyNegativeIndex = false),
+      null)
+    checkEvaluation(
+      ArrayInsert(
+        Literal.create(Seq(""), ArrayType(StringType)),
+        Literal(-1),
+        Literal.create(null, StringType),
+        legacyNegativeIndex = false),
+      Seq("", null))
     checkEvaluation(new ArrayInsert(
       a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
     )
@@ -2336,6 +2357,38 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       Seq("b", null, "d", "a", "g", null))
     checkEvaluation(new ArrayInsert(a11, Literal(3), Literal("d")), null)
     checkEvaluation(new ArrayInsert(a10, Literal.create(null, IntegerType), 
Literal("d")), null)
+
+    assert(
+      ArrayInsert(
+        Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
+        Literal(-1),
+        Literal.create(3, IntegerType),
+        legacyNegativeIndex = false)
+        .checkInputDataTypes() ==
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> "`array_insert`",
+            "dataType" -> "\"ARRAY\"",
+            "leftType" -> "\"ARRAY<DOUBLE>\"",
+            "rightType" -> "\"INT\""))
+    )
+
+    assert(
+      ArrayInsert(
+        Literal.create("Hi", StringType),
+        Literal(-1),
+        Literal.create("Spark", StringType),
+        legacyNegativeIndex = false)
+        .checkInputDataTypes() == DataTypeMismatch(
+        errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> "`array_insert`",
+          "dataType" -> "\"ARRAY\"",
+          "leftType" -> "\"STRING\"",
+          "rightType" -> "\"STRING\"")
+      )
+    )
   }
 
   test("Array Intersect") {
@@ -2679,86 +2732,14 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     }
   }
 
-  test("ArrayAppend Expression Test") {
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(null, ArrayType(StringType)),
-        Literal.create("c", StringType)),
-      null)
-
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(null, ArrayType(StringType)),
-        Literal.create(null, StringType)),
-      null)
-
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(Seq(""), ArrayType(StringType)),
-        Literal.create(null, StringType)),
-      Seq("", null))
-
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(Seq("a", "b", "c"), ArrayType(StringType)),
-        Literal.create(null, StringType)),
-      Seq("a", "b", "c", null))
-
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType)),
-        Literal.create(3d, DoubleType)),
-      Seq(Double.NaN, 1d, 2d, 3d))
-    // Null entry check
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
-        Literal.create(3d, DoubleType)),
-      Seq(null, 1d, 2d, 3d))
-
-    checkEvaluation(
-      ArrayAppend(
-        Literal.create(Seq("a", "b", "c"), ArrayType(StringType)),
-        Literal.create("c", StringType)),
-      Seq("a", "b", "c", "c"))
-
-    assert(
-      ArrayAppend(
-        Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
-        Literal.create(3, IntegerType))
-        .checkInputDataTypes() ==
-        DataTypeMismatch(
-          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
-          messageParameters = Map(
-            "functionName" -> "`array_append`",
-            "dataType" -> "\"ARRAY\"",
-            "leftType" -> "\"ARRAY<DOUBLE>\"",
-            "rightType" -> "\"INT\""))
-    )
-
-
-    assert(
-      ArrayAppend(
-        Literal.create("Hi", StringType),
-        Literal.create("Spark", StringType))
-        .checkInputDataTypes() == DataTypeMismatch(
-        errorSubClass = "UNEXPECTED_INPUT_TYPE",
-        messageParameters = Map(
-          "paramIndex" -> "0",
-          "requiredType" -> "\"ARRAY\"",
-          "inputSql" -> "\"Hi\"",
-          "inputType" -> "\"STRING\""
-        )
-      )
-    )
-
-  }
-
   test("SPARK-42401: Array insert of null value (explicit)") {
     val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
-    checkEvaluation(new ArrayInsert(
-      a, Literal(2), Literal.create(null, StringType)), Seq("b", null, "a", 
"c")
-    )
+    checkEvaluation(
+      new ArrayInsert(a, Literal(2), Literal.create(null, StringType)),
+      Seq("b", null, "a", "c"))
+    checkEvaluation(
+      new ArrayInsert(a, Literal(-1), Literal.create(null, StringType)),
+      Seq("b", "a", "c", null))
   }
 
   test("SPARK-42401: Array insert of null value (implicit)") {
@@ -2767,11 +2748,4 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c", 
null, "q")
     )
   }
-
-  test("SPARK-42401: Array append of null value") {
-    val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
-    checkEvaluation(ArrayAppend(
-      a, Literal.create(null, StringType)), Seq("b", "a", "c", null)
-    )
-  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to