This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bfb0f016817d [SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])` to `count(1)` on client side
bfb0f016817d is described below

commit bfb0f016817d9abfb648bd47f7c5164e6e1004a7
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Jan 16 18:02:36 2024 +0800

    [SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])` to `count(1)` on client side
    
    ### What changes were proposed in this pull request?
    Before https://github.com/apache/spark/pull/44689, `df["*"]` and
    `sf.col("*")` were both converted to `UnresolvedStar`, and
    `Count(UnresolvedStar)` was then converted to `Count(1)` in the Analyzer:
    
https://github.com/apache/spark/blob/381f3691bd481abc8f621ca3f282e06db32bea31/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L1893-L1897
    
    In that fix, we introduced a new node `UnresolvedDataFrameStar` for
    `df["*"]`, which is later replaced by `ResolvedStar`. Unfortunately, the
    new node no longer matches the `Count(UnresolvedStar)` rule.
    As a result:
    ```
    In [1]: from pyspark.sql import functions as sf
    
    In [2]: df1 = spark.createDataFrame([{"id": 1, "val": "v"}])
    
    In [3]: df1.select(sf.count(df1["*"]))
    Out[3]: DataFrame[count(id, val): bigint]
    ```
    
    which should be
    ```
    In [3]: df1.select(sf.count(df1["*"]))
    Out[3]: DataFrame[count(1): bigint]
    ```
    
    In vanilla Spark, the `count` function itself performs this conversion,
    `sf.count(df1["*"])` -> `sf.count(sf.lit(1))`; see
    
    
https://github.com/apache/spark/blob/e8dfcd3081abe16b2115bb2944a2b1cb547eca8e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L422-L436
    
    So fixing this behavior on the client side is the natural approach.
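
The client-side rewrite described above can be sketched in plain Python. This is only an illustrative model, not the Spark Connect implementation: the classes `UnresolvedStar`, `Literal`, `ColumnRef`, and `UnresolvedFunction` below are hypothetical stand-ins for the client's expression nodes, and `count` here is a simplified analogue of the patched function.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class UnresolvedStar:
    """Hypothetical stand-in for a client-side star expression (`*` or `s.*`)."""
    target: Optional[str] = None


@dataclass(frozen=True)
class Literal:
    """Hypothetical stand-in for a literal expression such as `lit(1)`."""
    value: int


@dataclass(frozen=True)
class ColumnRef:
    """Hypothetical stand-in for a reference to a named column."""
    name: str


@dataclass(frozen=True)
class UnresolvedFunction:
    """Hypothetical stand-in for an unresolved function call in the plan."""
    name: str
    args: Tuple[object, ...]


def count(expr: object) -> UnresolvedFunction:
    # Turn count(*) into count(1) before the plan leaves the client, so the
    # result does not depend on which star node the server later produces.
    if isinstance(expr, UnresolvedStar):
        expr = Literal(1)
    return UnresolvedFunction("count", (expr,))


# Any star argument is rewritten; other arguments pass through unchanged.
star_plan = count(UnresolvedStar())
nested_star_plan = count(UnresolvedStar("s"))
column_plan = count(ColumnRef("id"))
```

Because the rewrite happens where the function call is built, the plan sent to the server already contains the literal, which is the same shape as the updated `groupby_agg.json` golden file in this commit.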
    
    ### Why are the changes needed?
    To keep the previous behavior: `count(df["*"])` should resolve to `count(1)`.
    
    ### Does this PR introduce _any_ user-facing change?
    It fixes a behavior change introduced in
    https://github.com/apache/spark/pull/44689
    
    ### How was this patch tested?
    Added a unit test (`test_count_star`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44752 from zhengruifeng/connect_fix_count_df_star.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../src/main/scala/org/apache/spark/sql/functions.scala |   9 ++++++++-
 .../test/resources/query-tests/queries/groupby_agg.json |   3 ++-
 .../resources/query-tests/queries/groupby_agg.proto.bin | Bin 208 -> 210 bytes
 python/pyspark/sql/connect/functions/builtin.py         |   2 ++
 python/pyspark/sql/tests/test_dataframe.py              |  15 +++++++++++++++
 5 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 9191633171f7..2a48958d4222 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -402,7 +402,14 @@ object functions {
    * @group agg_funcs
    * @since 3.4.0
    */
-  def count(e: Column): Column = Column.fn("count", e)
+  def count(e: Column): Column = {
+    val withoutStar = e.expr.getExprTypeCase match {
+      // Turn count(*) into count(1)
+      case proto.Expression.ExprTypeCase.UNRESOLVED_STAR => lit(1)
+      case _ => e
+    }
+    Column.fn("count", withoutStar)
+  }
 
   /**
    * Aggregate function: returns the number of items in a group.
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
index 4a1cfddb0288..65f266794828 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
@@ -81,7 +81,8 @@
       "unresolvedFunction": {
         "functionName": "count",
         "arguments": [{
-          "unresolvedStar": {
+          "literal": {
+            "integer": 1
           }
         }]
       }
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
index cfd6c2daa84b..18d8c6ce4115 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin differ
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 2eeefc9fae23..1e22a42c6241 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1010,6 +1010,8 @@ corr.__doc__ = pysparkfuncs.corr.__doc__
 
 
 def count(col: "ColumnOrName") -> Column:
+    if isinstance(col, Column) and isinstance(col._expr, UnresolvedStar):
+        col = lit(1)
     return _invoke_function_over_columns("count", col)
 
 
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 407ab22a088c..1788f1d9fb1a 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -104,6 +104,21 @@ class DataFrameTestsMixin:
         self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
         self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
 
+    def test_count_star(self):
+        df1 = self.spark.createDataFrame([{"a": 1}])
+        df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}])
+        df3 = df2.select(struct("a", "b").alias("s"))
+
+        self.assertEqual(df1.select(count(df1["*"])).columns, ["count(1)"])
+        self.assertEqual(df1.select(count(col("*"))).columns, ["count(1)"])
+
+        self.assertEqual(df2.select(count(df2["*"])).columns, ["count(1)"])
+        self.assertEqual(df2.select(count(col("*"))).columns, ["count(1)"])
+
+        self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"])
+        self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"])
+        self.assertEqual(df3.select(count(col("s.*"))).columns, ["count(1)"])
+
     def test_self_join(self):
         df1 = self.spark.range(10).withColumn("a", lit(0))
         df2 = df1.withColumnRenamed("a", "b")

