This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f5365d0dc59 [SPARK-45034][SQL] Support deterministic mode function f5365d0dc59 is described below commit f5365d0dc590d4965a269da223dbd72fbb764595 Author: Peter Toth <peter.t...@gmail.com> AuthorDate: Sun Sep 17 21:37:57 2023 +0300 [SPARK-45034][SQL] Support deterministic mode function ### What changes were proposed in this pull request? This PR adds a new optional argument to the `mode` aggregate function to provide deterministic results. When multiple values have the same greatest frequency then the new boolean argument can be used to get the lowest or highest value instead of an arbitrary one. ### Why are the changes needed? To make the function more user friendly. ### Does this PR introduce _any_ user-facing change? Yes, it adds a new argument to the `mode` function. ### How was this patch tested? Added new UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42755 from peter-toth/SPARK-45034-deterministic-mode-function. 
Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../scala/org/apache/spark/sql/functions.scala | 14 ++- .../explain-results/function_mode.explain | 2 +- .../query-tests/queries/function_mode.json | 4 + .../query-tests/queries/function_mode.proto.bin | Bin 173 -> 179 bytes python/pyspark/sql/connect/functions.py | 4 +- python/pyspark/sql/functions.py | 35 ++++-- .../sql/catalyst/expressions/aggregate/Mode.scala | 76 ++++++++++-- .../scala/org/apache/spark/sql/functions.scala | 16 ++- .../sql-functions/sql-expression-schema.md | 2 +- .../sql-tests/analyzer-results/group-by.sql.out | 120 ++++++++++++++++++- .../test/resources/sql-tests/inputs/group-by.sql | 11 ++ .../resources/sql-tests/results/group-by.sql.out | 132 ++++++++++++++++++++- .../apache/spark/sql/DatasetAggregatorSuite.scala | 10 ++ 13 files changed, 397 insertions(+), 29 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index b2102d4ba55..83f0ee64501 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -827,7 +827,19 @@ object functions { * @group agg_funcs * @since 3.4.0 */ - def mode(e: Column): Column = Column.fn("mode", e) + def mode(e: Column): Column = mode(e, deterministic = false) + + /** + * Aggregate function: returns the most frequent value in a group. + * + * When multiple values have the same greatest frequency then either any of values is returned + * if deterministic is false or is not defined, or the lowest value is returned if deterministic + * is true. + * + * @group agg_funcs + * @since 4.0.0 + */ + def mode(e: Column, deterministic: Boolean): Column = Column.fn("mode", e, lit(deterministic)) /** * Aggregate function: returns the maximum value of the expression in a group. 
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain index dfa2113a2c3..28bbb44b0fd 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_mode.explain @@ -1,2 +1,2 @@ -Aggregate [mode(a#0, 0, 0) AS mode(a)#0] +Aggregate [mode(a#0, 0, 0, false) AS mode(a, false)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json index 8e8183e9e08..5c26edee803 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.json @@ -18,6 +18,10 @@ "unresolvedAttribute": { "unparsedIdentifier": "a" } + }, { + "literal": { + "boolean": false + } }] } }] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin index dca0953a387..cc115e43172 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_mode.proto.bin differ diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 892ad6e6295..f89b1aae500 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -1136,8 +1136,8 @@ def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column: min_by.__doc__ = pysparkfuncs.min_by.__doc__ -def mode(col: "ColumnOrName") -> Column: - return 
_invoke_function_over_columns("mode", col) +def mode(col: "ColumnOrName", deterministic: bool = False) -> Column: + return _invoke_function("mode", _to_col(col), lit(deterministic)) mode.__doc__ = pysparkfuncs.mode.__doc__ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 31936241619..1e12b9bf469 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -737,16 +737,21 @@ def abs(col: "ColumnOrName") -> Column: @_try_remote_functions -def mode(col: "ColumnOrName") -> Column: +def mode(col: "ColumnOrName", deterministic: bool = False) -> Column: """ Returns the most frequent value in a group. .. versionadded:: 3.4.0 + .. versionchanged:: 4.0.0 + Supports deterministic argument. + Parameters ---------- col : :class:`~pyspark.sql.Column` or str target column to compute on. + deterministic : bool, optional + if there are multiple equally-frequent results then return the lowest (defaults to false). Returns ------- @@ -765,14 +770,26 @@ def mode(col: "ColumnOrName") -> Column: ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], ... schema=("course", "year", "earnings")) >>> df.groupby("course").agg(mode("year")).show() - +------+----------+ - |course|mode(year)| - +------+----------+ - | Java| 2012| - |dotNET| 2012| - +------+----------+ - """ - return _invoke_function_over_columns("mode", col) + +------+-----------------+ + |course|mode(year, false)| + +------+-----------------+ + | Java| 2012| + |dotNET| 2012| + +------+-----------------+ + + When multiple values have the same greatest frequency then either any of values is returned if + deterministic is false or is not defined, or the lowest value is returned if deterministic is + true. 
+ + >>> df2 = spark.createDataFrame([(-10,), (0,), (10,)], ["col"]) + >>> df2.select(mode("col", False), mode("col", True)).show() + +----------------+---------------+ + |mode(col, false)|mode(col, true)| + +----------------+---------------+ + | 0| -10| + +----------------+---------------+ + """ + return _invoke_function("mode", _to_java_column(col), deterministic) @_try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index cad7d1f07dc..4ac44d9d2c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -18,15 +18,22 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes} -import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes, Literal} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} +import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType} import org.apache.spark.util.collection.OpenHashMap // scalastyle:off line.size.limit 
@ExpressionDescription( - usage = "_FUNC_(col) - Returns the most frequent value for the values within `col`. NULL values are ignored. If all the values are NULL, or there are 0 rows, returns NULL.", + usage = """ + _FUNC_(col[, deterministic]) - Returns the most frequent value for the values within `col`. NULL values are ignored. If all the values are NULL, or there are 0 rows, returns NULL. + When multiple values have the same greatest frequency then either any of values is returned if `deterministic` is false or is not defined, or the lowest value is returned if `deterministic` is true.""", examples = """ Examples: > SELECT _FUNC_(col) FROM VALUES (0), (10), (10) AS tab(col); @@ -35,6 +42,10 @@ import org.apache.spark.util.collection.OpenHashMap 0-10 > SELECT _FUNC_(col) FROM VALUES (0), (10), (10), (null), (null), (null) AS tab(col); 10 + > SELECT _FUNC_(col, false) FROM VALUES (-10), (0), (10) AS tab(col); + 0 + > SELECT _FUNC_(col, true) FROM VALUES (-10), (0), (10) AS tab(col); + -10 """, group = "agg_funcs", since = "3.4.0") @@ -42,17 +53,53 @@ import org.apache.spark.util.collection.OpenHashMap case class Mode( child: Expression, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends TypedAggregateWithHashMapAsBuffer - with ImplicitCastInputTypes with UnaryLike[Expression] { + inputAggBufferOffset: Int = 0, + deterministicExpr: Expression = Literal.FalseLiteral) + extends TypedAggregateWithHashMapAsBuffer with ImplicitCastInputTypes + with BinaryLike[Expression] { def this(child: Expression) = this(child, 0, 0) + def this(child: Expression, deterministicExpr: Expression) = { + this(child, 0, 0, deterministicExpr) + } + + @transient + protected lazy val deterministicResult = deterministicExpr.eval().asInstanceOf[Boolean] + + override def left: Expression = child + + override def right: Expression = deterministicExpr + // Returns null for empty inputs override def nullable: Boolean = true override def dataType: DataType = child.dataType - 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + if (!deterministicExpr.foldable) { + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> toSQLId("deterministic"), + "inputType" -> toSQLType(deterministicExpr.dataType), + "inputExpr" -> toSQLExpr(deterministicExpr) + ) + ) + } else if (deterministicExpr.eval() == null) { + DataTypeMismatch( + errorSubClass = "UNEXPECTED_NULL", + messageParameters = Map("exprName" -> toSQLId("deterministic"))) + } else { + TypeCheckSuccess + } + } override def prettyName: String = "mode" @@ -81,7 +128,16 @@ case class Mode( return null } - buffer.maxBy(_._2)._1 + (if (deterministicResult) { + // When deterministic result is required but multiple keys have the same greatest frequency + // then let's select the lowest. 
+ val defaultKeyOrdering = + PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]] + val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering.reverse) + buffer.maxBy { case (key, count) => (count, key) }(ordering) + } else { + buffer.maxBy(_._2) + })._1 } override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode = @@ -90,8 +146,8 @@ case class Mode( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Mode = copy(inputAggBufferOffset = newInputAggBufferOffset) - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(child = newLeft, deterministicExpr = newRight) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5935695818e..dcde01ec408 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -870,7 +870,21 @@ object functions { * @group agg_funcs * @since 3.4.0 */ - def mode(e: Column): Column = withAggregateFunction { Mode(e.expr) } + def mode(e: Column): Column = mode(e, deterministic = false) + + /** + * Aggregate function: returns the most frequent value in a group. + * + * When multiple values have the same greatest frequency then either any of values is returned + * if deterministic is false or is not defined, or the lowest value is returned if deterministic + * is true. + * + * @group agg_funcs + * @since 4.0.0 + */ + def mode(e: Column, deterministic: Boolean): Column = withAggregateFunction { + Mode(e.expr, deterministicExpr = lit(deterministic).expr) + } /** * Aggregate function: returns the maximum value of the expression in a group. 
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index f518a67e1fa..9e06d5ac58a 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -404,7 +404,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.Median | median | SELECT median(col) FROM VALUES (0), (10) AS tab(col) | struct<median(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct<min(col):int> | | org.apache.spark.sql.catalyst.expressions.aggregate.MinBy | min_by | SELECT min_by(x, y) FROM VALUES ('a', 10), ('b', 50), ('c', 20) AS tab(x, y) | struct<min_by(x, y):string> | -| org.apache.spark.sql.catalyst.expressions.aggregate.Mode | mode | SELECT mode(col) FROM VALUES (0), (10), (10) AS tab(col) | struct<mode(col):int> | +| org.apache.spark.sql.catalyst.expressions.aggregate.Mode | mode | SELECT mode(col) FROM VALUES (0), (10), (10) AS tab(col) | struct<mode(col, false):int> | | org.apache.spark.sql.catalyst.expressions.aggregate.Percentile | percentile | SELECT percentile(col, 0.3) FROM VALUES (0), (10) AS tab(col) | struct<percentile(col, 0.3, 1):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgx(y, x):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgy(y, x):double> | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 202ceee1804..56b2553045f 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1155,7 +1155,7 @@ Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_lis -- !query SELECT mode(a), mode(b) FROM testData -- !query analysis -Aggregate [mode(a#x, 0, 0) AS mode(a)#x, mode(b#x, 0, 0) AS mode(b)#x] +Aggregate [mode(a#x, 0, 0, false) AS mode(a, false)#x, mode(b#x, 0, 0, false) AS mode(b, false)#x] +- SubqueryAlias testdata +- View (`testData`, [a#x,b#x]) +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] @@ -1168,7 +1168,7 @@ Aggregate [mode(a#x, 0, 0) AS mode(a)#x, mode(b#x, 0, 0) AS mode(b)#x] SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a -- !query analysis Sort [a#x ASC NULLS FIRST], true -+- Aggregate [a#x], [a#x, mode(b#x, 0, 0) AS mode(b)#x] ++- Aggregate [a#x], [a#x, mode(b#x, 0, 0, false) AS mode(b, false)#x] +- SubqueryAlias testdata +- View (`testData`, [a#x,b#x]) +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] @@ -1196,3 +1196,119 @@ Aggregate [c#x], [(c#x * 2) AS d#x] +- Project [if ((a#x < 0)) 0 else a#x AS b#x] +- SubqueryAlias t1 +- LocalRelation [a#x] + + +-- !query +SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, true) AS mode(col, true)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + 
"messageParameters" : { + "inputSql" : "\"true\"", + "inputType" : "\"STRING\"", + "paramIndex" : "2", + "requiredType" : "\"BOOLEAN\"", + "sqlExpr" : "\"mode(col, true)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "mode(col, 'true')" + } ] +} + + +-- !query +SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL", + "sqlState" : "42K09", + "messageParameters" : { + "exprName" : "`deterministic`", + "sqlExpr" : "\"mode(col, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "mode(col, null)" + } ] +} + + +-- !query +SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"b\"", + "inputName" : "`deterministic`", + "inputType" : "\"BOOLEAN\"", + "sqlExpr" : "\"mode(col, b)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "mode(col, b)" + } ] +} + + +-- !query +SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, false) AS mode(col, false)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] + + +-- !query +SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query analysis +Aggregate [mode(col#x, 0, 0, true) AS mode(col, true)#x] ++- SubqueryAlias tab + +- LocalRelation [col#x] diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c35cdb0de27..4b76510b65f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -264,3 +264,14 @@ FROM ( GROUP BY b ) t3 GROUP BY c; + +-- SPARK-45034: Support deterministic mode function +SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col); +SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col); +SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col); +SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col); +SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col); +SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b); +SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col); +SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col); +SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index db79646fe43..ac92c369de2 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1089,7 +1089,7 @@ struct<a:int,collect_list(b):array<int>,collect_list(b):array<int>> -- !query SELECT mode(a), mode(b) FROM testData -- !query schema -struct<mode(a):int,mode(b):int> +struct<mode(a, false):int,mode(b, false):int> -- !query output 3 1 @@ -1097,7 +1097,7 @@ struct<mode(a):int,mode(b):int> -- !query SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a -- !query schema -struct<a:int,mode(b):int> +struct<a:int,mode(b, false):int> -- !query output NULL 1 1 1 @@ -1121,3 +1121,131 @@ struct<d:int> -- !query output 0 2 + + +-- !query +SELECT mode(col) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query schema +struct<mode(col, false):int> +-- !query output +0 + + 
+-- !query +SELECT mode(col, false) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query schema +struct<mode(col, false):int> +-- !query output +0 + + +-- !query +SELECT mode(col, true) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query schema +struct<mode(col, true):int> +-- !query output +-10 + + +-- !query +SELECT mode(col, 'true') FROM VALUES (-10), (0), (10) AS tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"true\"", + "inputType" : "\"STRING\"", + "paramIndex" : "2", + "requiredType" : "\"BOOLEAN\"", + "sqlExpr" : "\"mode(col, true)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "mode(col, 'true')" + } ] +} + + +-- !query +SELECT mode(col, null) FROM VALUES (-10), (0), (10) AS tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL", + "sqlState" : "42K09", + "messageParameters" : { + "exprName" : "`deterministic`", + "sqlExpr" : "\"mode(col, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "mode(col, null)" + } ] +} + + +-- !query +SELECT mode(col, b) FROM VALUES (-10, false), (0, false), (10, false) AS tab(col, b) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "sqlState" : "42K09", + "messageParameters" : { + "inputExpr" : "\"b\"", + "inputName" : "`deterministic`", + "inputType" : "\"BOOLEAN\"", + "sqlExpr" : "\"mode(col, b)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 19, + "fragment" : "mode(col, b)" + } ] +} + + 
+-- !query +SELECT mode(col) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query schema +struct<mode(col, false):map<int,string>> +-- !query output +{1:"a"} + + +-- !query +SELECT mode(col, false) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query schema +struct<mode(col, false):map<int,string>> +-- !query output +{1:"a"} + + +-- !query +SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkIllegalArgumentException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2005", + "messageParameters" : { + "dataType" : "PhysicalMapType" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index e9daa825dd4..2de2d90e7dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -432,4 +432,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession { val agg = df.select(mode(col("a"))).as[String] checkDataset(agg, "3") } + + test("SPARK-45034: Support deterministic mode function") { + val df = Seq(-10, 0, 10).toDF("col") + + val agg = df.select(mode(col("col"), false)) + checkAnswer(agg, Row(0)) + + val agg2 = df.select(mode(col("col"), true)) + checkAnswer(agg2, Row(-10)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org