This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 50f20fbad36 [SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend` 50f20fbad36 is described below commit 50f20fbad36dbb05a5132a3043364eff7b1a565c Author: Max Gekk <max.g...@gmail.com> AuthorDate: Fri Aug 25 22:51:08 2023 -0700 [SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend` ### What changes were proposed in this pull request? In the PR, I propose replacing the current implementation of the `ArrayAppend` expression with a runtime-replaceable one that delegates to `ArrayInsert` with `posExpr = -1`. ### Why are the changes needed? To improve code maintainability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the modified test suite: ``` $ build/sbt "test:testOnly *CollectionExpressionsSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42660 from MaxGekk/array_append-to-insert-1. Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../explain-results/function_array_append.explain | 2 +- .../expressions/collectionOperations.scala | 229 ++++++--------------- .../expressions/CollectionExpressionsSuite.scala | 144 ++++++------- 3 files changed, 119 insertions(+), 256 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain index ca2804ebb60..e857e2e974f 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain @@ -1,2 +1,2 @@ -Project [array_append(e#0, 1) AS array_append(e, 1)#0] +Project [array_insert(e#0, -1, 1, false) AS array_append(e, 1)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fe9c4015c15..957aa1ab2d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1397,29 +1397,9 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } -@ExpressionDescription( - usage = """ - _FUNC_(array, element) - Add the element at the beginning of the array passed as first - argument. Type of element should be the same as the type of the elements of the array. - Null element is also prepended to the array. But if the array passed is NULL - output is NULL - """, - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); - ["d","b","d","c","a"] - > SELECT _FUNC_(array(1, 2, 3, null), null); - [null,1,2,3,null] - > SELECT _FUNC_(CAST(null as Array<Int>), 2); - NULL - """, - group = "array_funcs", - since = "3.5.0") -case class ArrayPrepend(left: Expression, right: Expression) extends RuntimeReplaceable +trait ArrayPendBase extends RuntimeReplaceable with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase { - override lazy val replacement: Expression = new ArrayInsert(left, Literal(1), right) - override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { case (ArrayType(e1, hasNull), e2) => @@ -1455,6 +1435,29 @@ case class ArrayPrepend(left: Expression, right: Expression) extends RuntimeRepl ) } } +} + +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Add the element at the beginning of the array passed as first + argument. Type of element should be the same as the type of the elements of the array. + Null element is also prepended to the array. 
But if the array passed is NULL + output is NULL + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [null,1,2,3,null] + > SELECT _FUNC_(CAST(null as Array<Int>), 2); + NULL + """, + group = "array_funcs", + since = "3.5.0") +case class ArrayPrepend(left: Expression, right: Expression) extends ArrayPendBase { + + override lazy val replacement: Expression = new ArrayInsert(left, Literal(1), right) override def prettyName: String = "array_prepend" @@ -1463,6 +1466,41 @@ case class ArrayPrepend(left: Expression, right: Expression) extends RuntimeRepl copy(left = newLeft, right = newRight) } + +/** + * Given an array, and another element append the element at the end of the array. + * This function does not return null when the elements are null. It appends null at + * the end of the array. But returns null if the array passed is null. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Add the element at the end of the array passed as first + argument. Type of element should be similar to type of the elements of the array. + Null element is also appended into the array. But if the array passed, is NULL + output is NULL + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["b","d","c","a","d"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [1,2,3,null,null] + > SELECT _FUNC_(CAST(null as Array<Int>), 2); + NULL + """, + since = "3.4.0", + group = "array_funcs") +case class ArrayAppend(left: Expression, right: Expression) extends ArrayPendBase { + + override lazy val replacement: Expression = new ArrayInsert(left, Literal(-1), right) + + override def prettyName: String = "array_append" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayAppend = + copy(left = newLeft, right = newRight) +} + /** * Checks if the two arrays contain at least one common element. 
*/ @@ -5039,152 +5077,3 @@ case class ArrayCompact(child: Expression) override protected def withNewChildInternal(newChild: Expression): ArrayCompact = copy(child = newChild) } - -/** - * Given an array, and another element append the element at the end of the array. - * This function does not return null when the elements are null. It appends null at - * the end of the array. But returns null if the array passed is null. - */ -@ExpressionDescription( - usage = """ - _FUNC_(array, element) - Add the element at the end of the array passed as first - argument. Type of element should be similar to type of the elements of the array. - Null element is also appended into the array. But if the array passed, is NULL - output is NULL - """, - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); - ["b","d","c","a","d"] - > SELECT _FUNC_(array(1, 2, 3, null), null); - [1,2,3,null,null] - > SELECT _FUNC_(CAST(null as Array<Int>), 2); - NULL - """, - since = "3.4.0", - group = "array_funcs") -case class ArrayAppend(left: Expression, right: Expression) - extends BinaryExpression - with ImplicitCastInputTypes - with ComplexTypeMergingExpression - with QueryErrorsBase { - override def prettyName: String = "array_append" - - @transient protected lazy val elementType: DataType = - inputTypes.head.asInstanceOf[ArrayType].elementType - - override def inputTypes: Seq[AbstractDataType] = { - (left.dataType, right.dataType) match { - case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findTightestCommonType(e1, e2) match { - case Some(dt) => Seq(ArrayType(dt, hasNull), dt) - case _ => Seq.empty - } - case _ => Seq.empty - } - } - - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess - case (ArrayType(e1, _), e2) => DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - 
"functionName" -> toSQLId(prettyName), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType), - "dataType" -> toSQLType(ArrayType) - )) - case _ => - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "0", - "requiredType" -> toSQLType(ArrayType), - "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType) - ) - ) - } - } - - override def eval(input: InternalRow): Any = { - val value1 = left.eval(input) - if (value1 == null) { - null - } else { - val value2 = right.eval(input) - nullSafeEval(value1, value2) - } - } - - override protected def nullSafeEval(arr: Any, elementData: Any): Any = { - val arrayData = arr.asInstanceOf[ArrayData] - val numberOfElements = arrayData.numElements() + 1 - if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) - } - val finalData = new Array[Any](numberOfElements) - arrayData.foreach(elementType, finalData.update) - finalData.update(arrayData.numElements(), elementData) - new GenericArrayData(finalData) - } - - override def nullable: Boolean = left.nullable - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val leftGen = left.genCode(ctx) - val rightGen = right.genCode(ctx) - val f = (eval1: String, eval2: String) => { - val newArraySize = ctx.freshName("newArraySize") - val i = ctx.freshName("i") - val values = ctx.freshName("values") - val allocation = CodeGenerator.createArrayData( - values, elementType, newArraySize, s" $prettyName failed.") - val assignment = CodeGenerator.createArrayAssignment( - values, elementType, eval1, i, i, left.dataType.asInstanceOf[ArrayType].containsNull) - s""" - |int $newArraySize = $eval1.numElements() + 1; - |$allocation - |int $i = 0; - |while ($i < $eval1.numElements()) { - | $assignment - | $i ++; - |} - |${CodeGenerator.setArrayElement(values, elementType, i, eval2, 
Some(rightGen.isNull))} - |${ev.value} = $values; - |""".stripMargin - } - val resultCode = f(leftGen.value, rightGen.value) - if (nullable) { - val nullSafeEval = - leftGen.code + rightGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { - s""" - ${ev.isNull} = false; // resultCode could change nullability. - $resultCode - """ - } - ev.copy(code = - code""" - boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $nullSafeEval - """) - } else { - ev.copy(code = - code""" - ${leftGen.code} - ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = FalseLiteral) - } - } - - /** - * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query - * the dataType of an unresolved expression (i.e., when `resolved` == false). - */ - override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType - protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): ArrayAppend = - copy(left = newLeft, right = newRight) - -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1787f6ac72d..ff393857c31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2328,6 +2328,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) // null handling + checkEvaluation( + ArrayInsert( + Literal.create(null, ArrayType(StringType)), + Literal(-1), + Literal.create("c", StringType), + legacyNegativeIndex = false), + null) + checkEvaluation( + ArrayInsert( + Literal.create(null, 
ArrayType(StringType)), + Literal(-1), + Literal.create(null, StringType), + legacyNegativeIndex = false), + null) + checkEvaluation( + ArrayInsert( + Literal.create(Seq(""), ArrayType(StringType)), + Literal(-1), + Literal.create(null, StringType), + legacyNegativeIndex = false), + Seq("", null)) checkEvaluation(new ArrayInsert( a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4) ) @@ -2336,6 +2357,38 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq("b", null, "d", "a", "g", null)) checkEvaluation(new ArrayInsert(a11, Literal(3), Literal("d")), null) checkEvaluation(new ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null) + + assert( + ArrayInsert( + Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)), + Literal(-1), + Literal.create(3, IntegerType), + legacyNegativeIndex = false) + .checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> "`array_insert`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY<DOUBLE>\"", + "rightType" -> "\"INT\"")) + ) + + assert( + ArrayInsert( + Literal.create("Hi", StringType), + Literal(-1), + Literal.create("Spark", StringType), + legacyNegativeIndex = false) + .checkInputDataTypes() == DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> "`array_insert`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"STRING\"", + "rightType" -> "\"STRING\"") + ) + ) } test("Array Intersect") { @@ -2679,86 +2732,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("ArrayAppend Expression Test") { - checkEvaluation( - ArrayAppend( - Literal.create(null, ArrayType(StringType)), - Literal.create("c", StringType)), - null) - - checkEvaluation( - ArrayAppend( - Literal.create(null, ArrayType(StringType)), - Literal.create(null, StringType)), - null) - - checkEvaluation( - 
ArrayAppend( - Literal.create(Seq(""), ArrayType(StringType)), - Literal.create(null, StringType)), - Seq("", null)) - - checkEvaluation( - ArrayAppend( - Literal.create(Seq("a", "b", "c"), ArrayType(StringType)), - Literal.create(null, StringType)), - Seq("a", "b", "c", null)) - - checkEvaluation( - ArrayAppend( - Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType)), - Literal.create(3d, DoubleType)), - Seq(Double.NaN, 1d, 2d, 3d)) - // Null entry check - checkEvaluation( - ArrayAppend( - Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)), - Literal.create(3d, DoubleType)), - Seq(null, 1d, 2d, 3d)) - - checkEvaluation( - ArrayAppend( - Literal.create(Seq("a", "b", "c"), ArrayType(StringType)), - Literal.create("c", StringType)), - Seq("a", "b", "c", "c")) - - assert( - ArrayAppend( - Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)), - Literal.create(3, IntegerType)) - .checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> "`array_append`", - "dataType" -> "\"ARRAY\"", - "leftType" -> "\"ARRAY<DOUBLE>\"", - "rightType" -> "\"INT\"")) - ) - - - assert( - ArrayAppend( - Literal.create("Hi", StringType), - Literal.create("Spark", StringType)) - .checkInputDataTypes() == DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> "0", - "requiredType" -> "\"ARRAY\"", - "inputSql" -> "\"Hi\"", - "inputType" -> "\"STRING\"" - ) - ) - ) - - } - test("SPARK-42401: Array insert of null value (explicit)") { val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) - checkEvaluation(new ArrayInsert( - a, Literal(2), Literal.create(null, StringType)), Seq("b", null, "a", "c") - ) + checkEvaluation( + new ArrayInsert(a, Literal(2), Literal.create(null, StringType)), + Seq("b", null, "a", "c")) + checkEvaluation( + new ArrayInsert(a, Literal(-1), Literal.create(null, StringType)), + Seq("b", "a", "c", null)) } 
test("SPARK-42401: Array insert of null value (implicit)") { @@ -2767,11 +2748,4 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c", null, "q") ) } - - test("SPARK-42401: Array append of null value") { - val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) - checkEvaluation(ArrayAppend( - a, Literal.create(null, StringType)), Seq("b", "a", "c", null) - ) - } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org