This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3dd629629ab [SPARK-41233][SQL][PYTHON] Add `array_prepend` function 3dd629629ab is described below commit 3dd629629ab151688b82a3aa66e1b5fa568afbfa Author: Navin Viswanath <navin.v...@gmail.com> AuthorDate: Thu Mar 16 17:51:33 2023 +0800 [SPARK-41233][SQL][PYTHON] Add `array_prepend` function ### What changes were proposed in this pull request? Adds a new array function array_prepend to catalyst. ### Why are the changes needed? This adds a function that exists in many SQL implementations, specifically Snowflake: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.functions.array_prepend.html ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Added unit tests. Closes #38947 from navinvishy/array-prepend. Lead-authored-by: Navin Viswanath <navin.v...@gmail.com> Co-authored-by: navinvishy <navin.v...@gmail.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../source/reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 30 +++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 146 +++++++++++++++++++++ .../expressions/CollectionExpressionsSuite.scala | 44 +++++++ .../scala/org/apache/spark/sql/functions.scala | 10 ++ .../sql-functions/sql-expression-schema.md | 3 +- .../src/test/resources/sql-tests/inputs/array.sql | 11 ++ .../resources/sql-tests/results/ansi/array.sql.out | 72 ++++++++++ .../test/resources/sql-tests/results/array.sql.out | 72 ++++++++++ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 68 ++++++++++ 11 files changed, 457 insertions(+), 1 deletion(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 70fc04ef9cf..cbc46e1fae1 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ 
b/python/docs/source/reference/pyspark.sql/functions.rst @@ -159,6 +159,7 @@ Collection Functions array_sort array_insert array_remove + array_prepend array_distinct array_intersect array_union diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 051fd52a13c..1f02be3ad21 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7631,6 +7631,36 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions +def array_prepend(col: "ColumnOrName", value: Any) -> Column: + """ + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + value : + a literal value, or a :class:`~pyspark.sql.Column` expression. + + Returns + ------- + :class:`~pyspark.sql.Column` + an array with the given value prepended. 
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function_over_columns("array_prepend", col, lit(value)) + + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ad82a836199..aca73741c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -697,6 +697,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 289859d420b..2ccb3a6d0cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,152 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Add the element at the beginning of the array passed as first + argument. Type of element should be the same as the type of the elements of the array. 
+ Null element is also prepended to the array. But if the array passed is NULL + output is NULL + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [null,1,2,3,null] + > SELECT _FUNC_(CAST(null as Array<Int>), 2); + NULL + """, + group = "array_funcs", + since = "3.5.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with ComplexTypeMergingExpression + with QueryErrorsBase { + + override def nullable: Boolean = left.nullable + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = s"$arr.numElements() + 1" + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val iPlus1 = s"$i+1" + val zero = "0" + val allocation = CodeGenerator.createArrayData( + newArray, + elementType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, 
elementType, arr, iPlus1, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) + s""" + |$allocation + |$newElemAssignment + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { + s""" + |${ev.isNull} = false; + |${resultCode} + |""".stripMargin + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """.stripMargin + ) + } else { + ev.copy(code = + code""" + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin, isNull = FalseLiteral) + } + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "0", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> 
toSQLType(left.dataType) + ) + ) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 60300ba62f2..3abc70a3d55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1855,6 +1855,50 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) + val b0 = Literal.create( + data, + ArrayType(BinaryType)) + val b1 = 
Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + // Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb5c1ad5c49..d771367f318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4043,6 +4043,16 @@ object functions { def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) } + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.5.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + } /** * Removes duplicate values from the array. 
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0894d03f9d4..6b5b67f9849 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -26,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct<array_max(array(1, 20, NULL, 3)):int> | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct<array_min(array(1, 20, NULL, 3)):int> | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> | @@ -421,4 +422,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> | | 
org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc..d3c36b79d1f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String)); select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY<String>), 'a'); +select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60..d228c605705 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct<array_prepend(array(1, 2, 3), 4):array<int>> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct<array_prepend(array(a, b, c), d):array<string>> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct<array_prepend(array(a, b, c, NULL), NULL):array<string>> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY<String>), 'a') +-- !query schema +struct<array_prepend(NULL, a):array<string>> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String)) +-- !query schema +struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct<array_prepend(array(), 1):array<int>> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String)) +-- !query schema +struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>> +-- !query output +[null,null] diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out 
b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d3..029bd767f54 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct<array_prepend(array(1, 2, 3), 4):array<int>> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct<array_prepend(array(a, b, c), d):array<string>> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct<array_prepend(array(a, b, c, NULL), NULL):array<string>> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY<String>), 'a') +-- !query schema +struct<array_prepend(NULL, a):array<string>> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String)) +-- !query schema +struct<array_prepend(NULL, CAST(NULL AS STRING)):array<string>> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct<array_prepend(array(), 1):array<int>> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String)) +-- !query schema +struct<array_prepend(array(), CAST(NULL AS STRING)):array<string>> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct<array_prepend(array(CAST(NULL AS STRING)), 
CAST(NULL AS STRING)):array<string>> +-- !query output +[null,null] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bd03d292820..355f2dfffb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,74 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "0", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = 
Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY<INT>\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org