This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new bfb0f016817d [SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])` to `count(1)` on client side bfb0f016817d is described below commit bfb0f016817d9abfb648bd47f7c5164e6e1004a7 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Tue Jan 16 18:02:36 2024 +0800 [SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])` to `count(1)` on client side ### What changes were proposed in this pull request? Before https://github.com/apache/spark/pull/44689, `df["*"]` and `sf.col("*")` were both converted to `UnresolvedStar`, and then `Count(UnresolvedStar)` was converted to `Count(1)` in the Analyzer: https://github.com/apache/spark/blob/381f3691bd481abc8f621ca3f282e06db32bea31/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L1893-L1897 In that fix, we introduced a new node `UnresolvedDataFrameStar` for `df["*"]` which will be replaced with `ResolvedStar` later. Unfortunately, it doesn't match `Count(UnresolvedStar)` any more. So it causes: ``` In [1]: from pyspark.sql import functions as sf In [2]: df1 = spark.createDataFrame([{"id": 1, "val": "v"}]) In [3]: df1.select(sf.count(df1["*"])) Out[3]: DataFrame[count(id, val): bigint] ``` which should be ``` In [3]: df1.select(sf.count(df1["*"])) Out[3]: DataFrame[count(1): bigint] ``` In vanilla Spark, it is up to the `count` function to make such a conversion, `sf.count(df1["*"])` -> `sf.count(sf.lit(1))`, see https://github.com/apache/spark/blob/e8dfcd3081abe16b2115bb2944a2b1cb547eca8e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L422-L436 So it is a natural way to fix this behavior on the client side. ### Why are the changes needed? To keep the existing behavior. ### Does this PR introduce _any_ user-facing change? It fixes a behavior change introduced in https://github.com/apache/spark/pull/44689 ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? 
no Closes #44752 from zhengruifeng/connect_fix_count_df_star. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../src/main/scala/org/apache/spark/sql/functions.scala | 9 ++++++++- .../test/resources/query-tests/queries/groupby_agg.json | 3 ++- .../resources/query-tests/queries/groupby_agg.proto.bin | Bin 208 -> 210 bytes python/pyspark/sql/connect/functions/builtin.py | 2 ++ python/pyspark/sql/tests/test_dataframe.py | 15 +++++++++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 9191633171f7..2a48958d4222 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -402,7 +402,14 @@ object functions { * @group agg_funcs * @since 3.4.0 */ - def count(e: Column): Column = Column.fn("count", e) + def count(e: Column): Column = { + val withoutStar = e.expr.getExprTypeCase match { + // Turn count(*) into count(1) + case proto.Expression.ExprTypeCase.UNRESOLVED_STAR => lit(1) + case _ => e + } + Column.fn("count", withoutStar) + } /** * Aggregate function: returns the number of items in a group. 
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json index 4a1cfddb0288..65f266794828 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json @@ -81,7 +81,8 @@ "unresolvedFunction": { "functionName": "count", "arguments": [{ - "unresolvedStar": { + "literal": { + "integer": 1 } }] } diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin index cfd6c2daa84b..18d8c6ce4115 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin differ diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 2eeefc9fae23..1e22a42c6241 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1010,6 +1010,8 @@ corr.__doc__ = pysparkfuncs.corr.__doc__ def count(col: "ColumnOrName") -> Column: + if isinstance(col, Column) and isinstance(col._expr, UnresolvedStar): + col = lit(1) return _invoke_function_over_columns("count", col) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 407ab22a088c..1788f1d9fb1a 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -104,6 +104,21 @@ class DataFrameTestsMixin: self.assertEqual(df.select(df2["*"]).columns, ["a", "b"]) self.assertEqual(df.select(df3["*"]).columns, ["x", "y"]) + def test_count_star(self): + df1 = self.spark.createDataFrame([{"a": 1}]) + df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}]) + 
df3 = df2.select(struct("a", "b").alias("s")) + + self.assertEqual(df1.select(count(df1["*"])).columns, ["count(1)"]) + self.assertEqual(df1.select(count(col("*"))).columns, ["count(1)"]) + + self.assertEqual(df2.select(count(df2["*"])).columns, ["count(1)"]) + self.assertEqual(df2.select(count(col("*"))).columns, ["count(1)"]) + + self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"]) + self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"]) + self.assertEqual(df3.select(count(col("s.*"))).columns, ["count(1)"]) + def test_self_join(self): df1 = self.spark.range(10).withColumn("a", lit(0)) df2 = df1.withColumnRenamed("a", "b") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org