This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new b00d0b460d6 [SPARK-41234][SQL][PYTHON] Add `array_insert` function
b00d0b460d6 is described below

commit b00d0b460d6ab05951a3620632b80fcb06907970
Author: Daniel Davies <ddav...@palantir.com>
AuthorDate: Mon Feb 6 16:00:11 2023 +0800

[SPARK-41234][SQL][PYTHON] Add `array_insert` function

### What changes were proposed in this pull request?

This PR implements the `array_insert` function, similar to the Snowflake implementation
[here](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_insert.html).
The main difference is that Spark arrays are 1-indexed while Snowflake arrays are
0-indexed, so for positive positions the definition of `array_insert` differs by one index.

#### Arguments

- `arr`: the array (of any element type) into which the element is inserted.
- `pos`: an integer column indicating the position of insertion (starting at index 1).
- `item`: the value to insert into `arr`; its type has to match the element type of the array.

`select array_insert(array(1, 2, 3), 3, 4);`

| array_insert(array(1, 2, 3), 3, 4) |
|------------------------------------|
| [1, 2, 4, 3] |

#### Support in other frameworks

- [MySQL](https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-array-insert)
- [Snowflake](https://docs.snowflake.com/en/sql-reference/functions/array_insert.html)
- [SQL Notebook](https://sqlnotebook.com/array-insert-func.html)

### Why are the changes needed?

`array_insert` is an existing Snowflake function that is not yet in Spark. Use cases
include adding an element to an array at a given index.

### Does this PR introduce _any_ user-facing change?

No changes to existing APIs; it adds a new `array_insert` function that can be
accessed via Scala Spark and PySpark.

### How was this patch tested?

Tests in CollectionExpressionsSuite (the Scala implementation class),
DataFrameFunctionsSuite, and the SQL golden-file tests (array.sql).

Closes #38867 from Daniel-Davies/ddavies/SPARK-41234.
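For a quick feel of the 1-based indexing, negative positions, and null padding described above, here is an illustrative spark-shell sketch; it assumes a build that already contains this patch, and the expected results mirror the array.sql golden files added below:

```scala
// Illustrative only; assumes a Spark 3.4 build containing this patch.
// Expected outputs are taken from the array.sql golden files in this commit.
spark.sql("SELECT array_insert(array(1, 2, 3), 3, 4)").show(false)
// [1, 2, 4, 3]              -- positions are 1-based; the old element at pos 3 shifts right

spark.sql("SELECT array_insert(array(2, 3, NULL, 4), -5, 1)").show(false)
// [1, null, 2, 3, null, 4]  -- negative positions count from the back; the gap is null-padded

spark.sql("SELECT array_insert(array(2, 3, NULL, 4), 5, 5)").show(false)
// [2, 3, null, 4, 5]
```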
Lead-authored-by: Daniel Davies <ddav...@palantir.com>
Co-authored-by: Daniel-Davies <33356828+daniel-dav...@users.noreply.github.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
(cherry picked from commit fbfcd81e65630260273148fb19ee3e471056119d)
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/pyspark/sql/functions.py                    |  37 +++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/collectionOperations.scala         | 251 +++++++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   |  72 ++++++
 .../scala/org/apache/spark/sql/functions.scala     |  10 +
 .../sql-functions/sql-expression-schema.md         |   1 +
 .../src/test/resources/sql-tests/inputs/array.sql  |  12 +
 .../resources/sql-tests/results/ansi/array.sql.out |  98 ++++++++
 .../test/resources/sql-tests/results/array.sql.out |  98 ++++++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  44 ++++
 11 files changed, 625 insertions(+)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index ddc8eab90f7..70fc04ef9cf 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -157,6 +157,7 @@ Collection Functions
     element_at
     array_append
     array_sort
+    array_insert
     array_remove
     array_distinct
     array_intersect

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3426f2bdaf6..157432e07b0 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7679,6 +7679,43 @@ def array_distinct(col: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("array_distinct", col)
 
 
+@try_remote_functions
+def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName") -> Column:
+    """
+    Collection function: adds an item into a given array at a specified array index.
+    Array indices start at 1 (or from the end if the index is negative).
+    An index beyond the current array length (plus the inserted element) extends
+    the array with 'null' elements.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    arr : :class:`~pyspark.sql.Column` or str
+        name of column containing an array
+    pos : :class:`~pyspark.sql.Column` or str
+        name of numeric type column indicating position of insertion
+        (starting at index 1; a negative position counts from the back of the array)
+    value : :class:`~pyspark.sql.Column` or str
+        name of column containing values for insertion into array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of values, including the new specified value
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame(
+    ...     [(['a', 'b', 'c'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')],
+    ...     ['data', 'pos', 'val']
+    ... )
+    >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
+    [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
+    """
+    return _invoke_function_over_columns("array_insert", arr, pos, value)
+
+
 @try_remote_functions
 def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 8d28826491f..6181cc7cae9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -661,6 +661,7 @@ object FunctionRegistry {
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
     expression[ArraysOverlap]("arrays_overlap"),
+    expression[ArrayInsert]("array_insert"),
     expression[ArrayIntersect]("array_intersect"),
     expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ca3982f54c8..92a3127d438 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -4601,6 +4601,257 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
     newLeft: Expression, newRight: Expression): ArrayExcept =
     copy(left = newLeft, right = newRight)
 }
 
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(x, pos, val) - Places val into index pos of array x (array indices start at 1, or start from the end if pos is negative).",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3, 4), 5, 5);
+       [1,2,3,4,5]
+      > SELECT _FUNC_(array(5, 3, 2, 1), -3, 4);
+       [5,4,3,2,1]
+  """,
+  group = "array_funcs",
+  since = "3.4.0")
+case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
+  with QueryErrorsBase {
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
+      case (ArrayType(e1, hasNull), e2: IntegralType, e3) if (e2 != LongType) =>
+        TypeCoercion.findTightestCommonType(e1, e3) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), IntegerType, dt)
+          case _ => Seq.empty
+        }
+      case (e1, e2, e3) => Seq.empty
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (first.dataType, second.dataType, third.dataType) match {
+      case (_: ArrayType, e2, e3) if e2 != IntegerType =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "2",
+            "requiredType" -> toSQLType(IntegerType),
+            "inputSql" -> toSQLExpr(second),
+            "inputType" -> toSQLType(second.dataType))
+        )
+      case (ArrayType(e1, _), e2, e3) if e1.sameType(e3) =>
+        TypeCheckResult.TypeCheckSuccess
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "dataType" -> toSQLType(ArrayType),
+            "leftType" -> toSQLType(first.dataType),
+            "rightType" -> toSQLType(third.dataType)
+          )
+        )
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = first.eval(input)
+    if (value1 != null) {
+      val value2 = second.eval(input)
+      if (value2 != null) {
+        val value3 = third.eval(input)
+        return nullSafeEval(value1, value2, value3)
+      }
+    }
+    null
+  }
+
+  override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
+    val baseArr = arr.asInstanceOf[ArrayData]
+    var posInt = pos.asInstanceOf[Int]
+    val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
+
+    val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())
+
+    if (newPosExtendsArrayLeft) {
+      // special case - if the new position is negative but larger than the current array size,
+      // place the new item at the start of the array, place the current array contents at the
+      // end, and fill the newly created array elements in between with nulls
+
+      val newArrayLength = -posInt + 1
+
+      if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+      }
+
+      val newArray = new Array[Any](newArrayLength)
+
+      baseArr.foreach(arrayElementType, (i, v) => {
+        // current position, offset by the new item plus the new null array elements
+        val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
+        newArray(elementPosition) = v
+      })
+
+      newArray(0) = item
+
+      return new GenericArrayData(newArray)
+    } else {
+      if (posInt < 0) {
+        posInt = posInt + baseArr.numElements()
+      } else if (posInt > 0) {
+        posInt = posInt - 1
+      }
+
+      val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
+
+      if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
+      }
+
+      val newArray = new Array[Any](newArrayLength)
+
+      baseArr.foreach(arrayElementType, (i, v) => {
+        if (i >= posInt) {
+          newArray(i + 1) = v
+        } else {
+          newArray(i) = v
+        }
+      })
+
+      newArray(posInt) = item
+
+      return new GenericArrayData(newArray)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
+      val arr = arrExpr.value
+      val pos = posExpr.value
+      val item = itemExpr.value
+
+      val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
+      val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
+      val resLength = ctx.freshName("resLength")
+      val insertedItemIsNull = ctx.freshName("insertedItemIsNull")
+      val i = ctx.freshName("i")
+      val j = ctx.freshName("j")
+      val values = ctx.freshName("values")
+
+      val allocation = CodeGenerator.createArrayData(
+        values, elementType, resLength, s"$prettyName failed.")
+      val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr,
+        adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull)
+
+      s"""
+         |int $itemInsertionIndex = 0;
+         |int $resLength = 0;
+         |int $adjustedAllocIdx = 0;
+         |boolean $insertedItemIsNull = ${itemExpr.isNull};
+         |
+         |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
+         |
+         |  $resLength = java.lang.Math.abs($pos) + 1;
+         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+         |    throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+         |  }
+         |
+         |  $allocation
+         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |    $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
+         |    $assignment
+         |  }
+         |  ${CodeGenerator.setArrayElement(
+              values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
+         |
+         |  for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
+         |    $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
+         |  }
+         |
+         |  ${ev.value} = $values;
+         |} else {
+         |
+         |  $itemInsertionIndex = 0;
+         |  if ($pos < 0) {
+         |    $itemInsertionIndex = $pos + $arr.numElements();
+         |  } else if ($pos > 0) {
+         |    $itemInsertionIndex = $pos - 1;
+         |  }
+         |
+         |  $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1);
+         |  if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+         |    throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
+         |  }
+         |
+         |  $allocation
+         |  for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |    $adjustedAllocIdx = $i;
+         |    if ($i >= $itemInsertionIndex) {
+         |      $adjustedAllocIdx = $adjustedAllocIdx + 1;
+         |    }
+         |    $assignment
+         |  }
+         |  ${CodeGenerator.setArrayElement(
+              values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
+         |
+         |  for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
+         |    $values.setNullAt($j);
+         |  }
+         |
+         |  ${ev.value} = $values;
+         |}
+      """.stripMargin
+    }
+
+    val leftGen = first.genCode(ctx)
+    val midGen = second.genCode(ctx)
+    val rightGen = third.genCode(ctx)
+    val resultCode = f(leftGen, midGen, rightGen)
+
+    if (nullable) {
+      val nullSafeEval =
+        leftGen.code + ctx.nullSafeExec(first.nullable, leftGen.isNull) {
+          midGen.code + ctx.nullSafeExec(second.nullable, midGen.isNull) {
+            s"""
+              ${rightGen.code}
+              ${ev.isNull} = false;
+              $resultCode
+            """
+          }
+        }
+
+      ev.copy(code = code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $nullSafeEval""")
+    } else {
+      ev.copy(code = code"""
+        ${leftGen.code}
+        ${midGen.code}
+        ${rightGen.code}
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $resultCode""", isNull = FalseLiteral)
+    }
+  }
+
+  override def first: Expression = srcArrayExpr
+  override def second: Expression = posExpr
+  override def third: Expression = itemExpr
+
+  override def prettyName: String = "array_insert"
+  override def dataType: DataType = first.dataType
+  override def nullable: Boolean = first.nullable | second.nullable
+
+  @transient private lazy val elementType: DataType =
+    srcArrayExpr.dataType.asInstanceOf[ArrayType].elementType
+
+  override protected def withNewChildrenInternal(
+      newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert =
+    copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr)
+}
+
 @ExpressionDescription(
   usage = "_FUNC_(array) - Removes null values from the array.",
   examples = """
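As an aside before the test suite: the interpreted path in `nullSafeEval` above condenses to the following standalone Scala sketch (illustrative only, not part of the patch; plain collections stand in for Spark's `ArrayData`):

```scala
// Condensed, illustrative sketch of the nullSafeEval index arithmetic above.
def arrayInsertSketch(arr: Seq[Any], pos: Int, item: Any): Seq[Any] = {
  if (pos < 0 && -pos > arr.length) {
    // Negative position beyond the left end: item first, null padding, then the old contents.
    (item +: Seq.fill[Any](-pos - arr.length)(null)) ++ arr
  } else {
    // Positive positions are 1-based (0 behaves like 1); negatives count from the back.
    val idx = if (pos < 0) pos + arr.length else math.max(pos - 1, 0)
    if (idx >= arr.length) {
      // Beyond the right end: pad the gap with nulls, then place the item.
      arr ++ Seq.fill[Any](idx - arr.length)(null) :+ item
    } else {
      (arr.take(idx) :+ item) ++ arr.drop(idx)
    }
  }
}

// e.g. arrayInsertSketch(Seq(1, 2, 4), -4, 3) == Seq(3, null, 1, 2, 4),
// matching the "index edge cases" assertions in the suite below.
```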
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d83739df38d..9b97430594d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2250,6 +2250,78 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(2d))
   }
 
+  test("Array Insert") {
+    val a1 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType))
+    val a3 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType))
+    val a4 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType))
+    val a5 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType))
+    val a6 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType))
+    val a7 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType))
+    val a8 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType))
+    val a9 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
+    val a10 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true))
+    val a11 = Literal.create(null, ArrayType(StringType))
+
+    // basic additions per type
+    checkEvaluation(ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
+    checkEvaluation(
+      ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
+      Seq[Boolean](true, false, true, true)
+    )
+    checkEvaluation(
+      ArrayInsert(
+        a4,
+        Literal(3),
+        Literal.create(5.asInstanceOf[Byte], ByteType)),
+      Seq[Byte](1, 2, 5, 3, 2))
+
+    checkEvaluation(
+      ArrayInsert(
+        a5,
+        Literal(3),
+        Literal.create(3.asInstanceOf[Short], ShortType)),
+      Seq[Short](1, 2, 3, 3, 2))
+
+    checkEvaluation(
+      ArrayInsert(a7, Literal(4), Literal(4.4)),
+      Seq[Double](1.1, 2.2, 3.3, 4.4, 2.2)
+    )
+    checkEvaluation(
+      ArrayInsert(a6, Literal(4), Literal(4.4F)),
+      Seq(1.1F, 2.2F, 3.3F, 4.4F, 2.2F)
+    )
+    checkEvaluation(ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
+    checkEvaluation(ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))
+
+    // index edge cases
+    checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(0), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
+    checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(-3), Literal(3)), Seq(3, 1, 2, 4))
+    checkEvaluation(ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, null, 1, 2, 4))
+    checkEvaluation(
+      ArrayInsert(a1, Literal(10), Literal(3)),
+      Seq(1, 2, 4, null, null, null, null, null, null, 3)
+    )
+    checkEvaluation(
+      ArrayInsert(a1, Literal(-10), Literal(3)),
+      Seq(3, null, null, null, null, null, null, null, 1, 2, 4)
+    )
+
+    // null handling
+    checkEvaluation(ArrayInsert(
+      a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
+    )
+    checkEvaluation(ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
+    checkEvaluation(ArrayInsert(a10, Literal(3), Literal("d")), Seq("b", null, "d", "a", "g", null))
+    checkEvaluation(ArrayInsert(a11, Literal(3), Literal("d")), null)
+    checkEvaluation(ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
+  }
+
   test("Array Intersect") {
     val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false))
     val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false))
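Ahead of the Scala API diff that follows, a minimal usage sketch (hypothetical data and column names; assumes a spark-shell session on a build with this patch):

```scala
// Hypothetical usage of the Scala API added in functions.scala below.
import org.apache.spark.sql.functions.{array_insert, col, lit}
import spark.implicits._  // already in scope in spark-shell

val df = Seq((Seq("a", "b", "c"), 2, "d")).toDF("data", "pos", "val")

// Position and value may be literals or driven by columns.
df.select(array_insert(col("data"), lit(2), lit("d"))).show()        // [a, d, b, c]
df.select(array_insert(col("data"), col("pos"), col("val"))).show()  // [a, d, b, c]
```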
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3d5547ead83..cb5c1ad5c49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4062,6 +4062,16 @@ object functions {
     ArrayIntersect(col1.expr, col2.expr)
   }
 
+  /**
+   * Adds an item into a given array at a specified position.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def array_insert(arr: Column, pos: Column, value: Column): Column = withExpr {
+    ArrayInsert(arr.expr, pos.expr, value.expr)
+  }
+
   /**
    * Returns an array of the elements in the union of the given two arrays, without duplicates.
    *

diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 9b8d50d2ede..ef5c4addc84 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -20,6 +20,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayExists | exists | SELECT exists(array(1, 2, 3), x -> x % 2 == 0) | struct<exists(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 0), namedlambdavariable())):boolean> |
 | org.apache.spark.sql.catalyst.expressions.ArrayFilter | filter | SELECT filter(array(1, 2, 3), x -> x % 2 == 1) | struct<filter(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 1), namedlambdavariable())):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayForAll | forall | SELECT forall(array(1, 2, 3), x -> x % 2 == 0) | struct<forall(array(1, 2, 3), lambdafunction(((namedlambdavariable() % 2) = 0), namedlambdavariable())):boolean> |
+| org.apache.spark.sql.catalyst.expressions.ArrayInsert | array_insert | SELECT array_insert(array(1, 2, 3, 4), 5, 5) | struct<array_insert(array(1, 2, 3, 4), 5, 5):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayIntersect | array_intersect | SELECT array_intersect(array(1, 2, 3), array(1, 3, 5)) | struct<array_intersect(array(1, 2, 3), array(1, 3, 5)):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayJoin | array_join | SELECT array_join(array('hello', 'world'), ' ') | struct<array_join(array(hello, world), ):string> |
 | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> |

diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index f2f0dac98e2..3d107cb6dfc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -130,6 +130,18 @@ select get(array(1, 2, 3), 3);
 select get(array(1, 2, 3), null);
 select get(array(1, 2, 3), -1);
 
+-- function array_insert()
+select array_insert(array(1, 2, 3), 3, 4);
+select array_insert(array(2, 3, 4), 0, 1);
+select array_insert(array(2, 3, 4), 1, 1);
+select array_insert(array(1, 3, 4), -2, 2);
+select array_insert(array(1, 2, 3), 3, "4");
+select array_insert(cast(NULL as ARRAY<INT>), 1, 1);
+select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4);
+select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT));
+select array_insert(array(2, 3, NULL, 4), 5, 5);
+select array_insert(array(2, 3, NULL, 4), -5, 1);
+
 -- function array_compact
 select array_compact(id) from values (1) as t(id);
 select array_compact(array("1", null, "2", null));

diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index 714775d7e8b..0d8ef39ed60 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -550,6 +550,104 @@ struct<get(array(1, 2, 3), -1):int>
 NULL
 
 
+-- !query
+select array_insert(array(1, 2, 3), 3, 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
+-- !query output
+[1,2,4,3]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 0, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 1, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 1, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 3, 4), -2, 2)
+-- !query schema
+struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 2, 3), 3, "4")
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "dataType" : "\"ARRAY\"",
+    "functionName" : "`array_insert`",
+    "leftType" : "\"ARRAY<INT>\"",
+    "rightType" : "\"STRING\"",
+    "sqlExpr" : "\"array_insert(array(1, 2, 3), 3, 4)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 43,
+    "fragment" : "array_insert(array(1, 2, 3), 3, \"4\")"
+  } ]
+}
+
+
+-- !query
+select array_insert(cast(NULL as ARRAY<INT>), 1, 1)
+-- !query schema
+struct<array_insert(NULL, 1, 1):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4):array<int>>
+-- !query output
+NULL
+
+
+-- !query
+select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT))
+-- !query schema
+struct<array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT)):array<int>>
+-- !query output
+[1,2,3,null,null]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), 5, 5)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), 5, 5):array<int>>
+-- !query output
+[2,3,null,4,5]
+
+
+-- !query
+select array_insert(array(2, 3, NULL, 4), -5, 1)
+-- !query schema
+struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>>
+-- !query output
+[1,null,2,3,null,4]
+
+
 -- !query
 select array_compact(id) from values (1) as t(id)
 -- !query schema

diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 94d4c7987a8..609122a23d3 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -431,6 +431,104 @@ struct<get(array(1, 2, 3), -1):int>
 NULL
 
 
+-- !query
+select array_insert(array(1, 2, 3), 3, 4)
+-- !query schema
+struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
+-- !query output
+[1,2,4,3]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 0, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(2, 3, 4), 1, 1)
+-- !query schema
+struct<array_insert(array(2, 3, 4), 1, 1):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 3, 4), -2, 2)
+-- !query schema
+struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
+-- !query output
+[1,2,3,4]
+
+
+-- !query
+select array_insert(array(1, 2, 3), 3, "4")
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "dataType" : "\"ARRAY\"",
+    "functionName" : "`array_insert`",
+    "leftType" : "\"ARRAY<INT>\"",
+    "rightType" : "\"STRING\"",
+    "sqlExpr" : "\"array_insert(array(1, 2, 3), 3, 4)\""
+  },
"queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 43, + "fragment" : "array_insert(array(1, 2, 3), 3, \"4\")" + } ] +} + + +-- !query +select array_insert(cast(NULL as ARRAY<INT>), 1, 1) +-- !query schema +struct<array_insert(NULL, 1, 1):array<int>> +-- !query output +NULL + + +-- !query +select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4) +-- !query schema +struct<array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4):array<int>> +-- !query output +NULL + + +-- !query +select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT)) +-- !query schema +struct<array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT)):array<int>> +-- !query output +[1,2,3,null,null] + + +-- !query +select array_insert(array(2, 3, NULL, 4), 5, 5) +-- !query schema +struct<array_insert(array(2, 3, NULL, 4), 5, 5):array<int>> +-- !query output +[2,3,null,4,5] + + +-- !query +select array_insert(array(2, 3, NULL, 4), -5, 1) +-- !query schema +struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>> +-- !query output +[1,null,2,3,null,4] + + -- !query select array_compact(id) from values (1) as t(id) -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b7daa45a42a..14def67ba40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3106,6 +3106,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("array_insert functions") { + val fiveShort: Short = 5 + + val df1 = Seq((Array[Integer](3, 2, 5, 1, 2), 6, 3)).toDF("a", "b", "c") + val df2 = Seq((Array[Short](1, 2, 3, 4), 5, fiveShort)).toDF("a", "b", "c") + val df3 = Seq((Array[Double](3.0, 2.0, 5.0, 1.0, 2.0), 2, 3.0)).toDF("a", "b", "c") + val df4 = Seq((Array[Boolean](true, false), 3, false)).toDF("a", "b", "c") + val df5 = Seq((Array[String]("a", "b", "c"), 0, "d")).toDF("a", "b", "c") + val df6 = Seq((Array[String]("a", null, "b", "c"), 5, "d")).toDF("a", "b", "c") + + checkAnswer(df1.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) + checkAnswer(df2.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq[Short](1, 2, 3, 4, 5)))) + checkAnswer( + df3.selectExpr("array_insert(a, b, c)"), + Seq(Row(Seq[Double](3.0, 3.0, 2.0, 5.0, 1.0, 2.0))) + ) + checkAnswer(df4.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(true, false, false)))) + checkAnswer(df5.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("d", "a", "b", "c")))) + checkAnswer(df5.select( + array_insert(col("a"), lit(1), col("c"))), + Seq(Row(Seq("d", "a", "b", "c"))) + ) + // null checks + checkAnswer(df6.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("a", null, "b", "c", "d")))) + checkAnswer(df5.select( + array_insert(col("a"), col("b"), lit(null).cast("string"))), + Seq(Row(Seq(null, "a", "b", "c"))) + ) + checkAnswer(df6.select( + array_insert(col("a"), col("b"), lit(null).cast("string"))), + Seq(Row(Seq("a", null, "b", "c", null))) + ) + checkAnswer( + df5.select(array_insert(col("a"), lit(null).cast("integer"), col("c"))), + Seq(Row(null)) + ) + checkAnswer( + df5.select(array_insert(lit(null).cast("array<string>"), col("b"), col("c"))), + Seq(Row(null)) + ) + checkAnswer(df1.selectExpr("array_insert(a, 7, c)"), Seq(Row(Seq(3, 2, 5, 1, 2, null, 3)))) + checkAnswer(df1.selectExpr("array_insert(a, -6, c)"), Seq(Row(Seq(3, null, 3, 2, 
+  }
+
   test("transform function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org