This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eb4bb44d446 [SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))`
eb4bb44d446 is described below

commit eb4bb44d446a0416c360da8127659b10f98e5ceb
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Jan 21 09:48:25 2023 +0800

    [SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))`
    
    ### What changes were proposed in this pull request?
    1, add `UnresolvedStar` to `expressions.py`;
    2, fix `count(*)` and `count(col(*))` so that they return `Column(UnresolvedStar(None))` instead of `Column(UnresolvedAttribute("*"))`, see: https://github.com/apache/spark/blob/68531ada34db72d352c39396f85458a8370af812/sql/core/src/main/scala/org/apache/spark/sql/Column.scala#L144-L150
    3, remove the `count(*) -> count(1)` transformation in `group.py`, since it is no longer needed.
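
    A minimal sketch of the new client-side mapping (illustrative only; the names match the `expressions.py` and `functions.py` changes in the diff below):

        from pyspark.sql.connect.functions import col

        col("*")    # -> Column(UnresolvedStar(unparsed_target=None))
        col("a.*")  # -> Column(UnresolvedStar(unparsed_target="a.*"))
        col("a")    # -> Column(ColumnReference(unparsed_identifier="a"))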
    
    ### Why are the changes needed?
    
    https://github.com/apache/spark/pull/39636 fixed the `count(*)` issue on the server side, and `count(expr(*))` works after that PR.
    
    This PR makes the corresponding changes on the Python client side, in order to support `count(*)` and `count(col(*))`.
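
    For example, the following now behaves the same on the Connect client as in vanilla PySpark (a sketch, assuming a Connect DataFrame `cdf` with an `age` column, mirroring the added test):

        from pyspark.sql.connect import functions as CF

        cdf.select(CF.count("*")).collect()           # count(*)
        cdf.select(CF.count(CF.col("*"))).collect()   # count(col("*"))
        cdf.select(CF.count(CF.expr("*"))).collect()  # count(expr("*"))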
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    Re-enabled the `count` doctest and added a new unit test (`test_count_star`).
    
    Closes #39622 from zhengruifeng/connect_count_star.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |  9 ++--
 .../sql/connect/planner/SparkConnectPlanner.scala  | 16 ++++--
 .../connect/planner/SparkConnectPlannerSuite.scala |  2 +-
 python/pyspark/sql/connect/dataframe.py            |  4 +-
 python/pyspark/sql/connect/expressions.py          | 46 +++++++++++++++--
 python/pyspark/sql/connect/functions.py            | 11 ++--
 python/pyspark/sql/connect/group.py                |  8 +--
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 30 +++++------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 31 ++++++++----
 .../sql/tests/connect/test_connect_function.py     | 58 ++++++++++++++++++++++
 10 files changed, 166 insertions(+), 49 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 349e2455be3..f7feae0e2f0 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -225,9 +225,12 @@ message Expression {
 
   // UnresolvedStar is used to expand all the fields of a relation or struct.
   message UnresolvedStar {
-    // (Optional) The target of the expansion, either be a table name or struct name, this
-    // is a list of identifiers that is the path of the expansion.
-    repeated string target = 1;
+
+    // (Optional) The target of the expansion.
+    //
+    // If set, it should end with '.*' and will be parsed by 'parseAttributeName'
+    // in the server side.
+    optional string unparsed_target = 1;
   }
 
   // Represents all of the input attributes to a given relational operator, for example in
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3d63558eb3e..d72aa162132 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1002,11 +1002,19 @@ class SparkConnectPlanner(session: SparkSession) {
     session.sessionState.sqlParser.parseExpression(expr.getExpression)
   }
 
-  private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = {
-    if (regex.getTargetList.isEmpty) {
-      UnresolvedStar(Option.empty)
+  private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = {
+    if (star.hasUnparsedTarget) {
+      val target = star.getUnparsedTarget
+      if (!target.endsWith(".*")) {
+        throw InvalidPlanInput(
+          s"UnresolvedStar requires a unparsed target ending with '.*', " +
+            s"but got $target.")
+      }
+
+      UnresolvedStar(
+        Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2))))
     } else {
-      UnresolvedStar(Some(regex.getTargetList.asScala.toSeq))
+      UnresolvedStar(None)
     }
   }
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 63e5415b44f..d8baa182e5a 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -552,7 +552,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         .addExpressions(
           proto.Expression
             .newBuilder()
-            .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build())
+            .setUnresolvedStar(UnresolvedStar.newBuilder().setUnparsedTarget("a.b.*").build())
             .build())
         .build()
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 11c0ef6fc06..d82862a870b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -249,7 +249,7 @@ class DataFrame:
                 else:
                     return Column(SortOrder(col._expr))
             else:
-                return Column(SortOrder(ColumnReference(name=col)))
+                return Column(SortOrder(ColumnReference(col)))
 
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
@@ -1176,7 +1176,7 @@ class DataFrame:
         from pyspark.sql.connect.expressions import ColumnReference
 
         if isinstance(col, str):
-            col = Column(ColumnReference(name=col))
+            col = Column(ColumnReference(col))
         elif not isinstance(col, Column):
             raise TypeError("col must be a string or a column, but got %r" % type(col))
         if not isinstance(fractions, dict):
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 6469c1917ec..c8d361af2a5 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -336,10 +336,10 @@ class ColumnReference(Expression):
     treat it as an unresolved attribute. Attributes that have the same fully
     qualified name are identical"""
 
-    def __init__(self, name: str) -> None:
+    def __init__(self, unparsed_identifier: str) -> None:
         super().__init__()
-        assert isinstance(name, str)
-        self._unparsed_identifier = name
+        assert isinstance(unparsed_identifier, str)
+        self._unparsed_identifier = unparsed_identifier
 
     def name(self) -> str:
         """Returns the qualified name of the column reference."""
@@ -354,6 +354,43 @@ class ColumnReference(Expression):
     def __repr__(self) -> str:
         return f"{self._unparsed_identifier}"
 
+    def __eq__(self, other: Any) -> bool:
+        return (
+            other is not None
+            and isinstance(other, ColumnReference)
+            and other._unparsed_identifier == self._unparsed_identifier
+        )
+
+
+class UnresolvedStar(Expression):
+    def __init__(self, unparsed_target: Optional[str]):
+        super().__init__()
+
+        if unparsed_target is not None:
+            assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")
+
+        self._unparsed_target = unparsed_target
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.unresolved_star.SetInParent()
+        if self._unparsed_target is not None:
+            expr.unresolved_star.unparsed_target = self._unparsed_target
+        return expr
+
+    def __repr__(self) -> str:
+        if self._unparsed_target is not None:
+            return f"unresolvedstar({self._unparsed_target})"
+        else:
+            return "unresolvedstar()"
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            other is not None
+            and isinstance(other, UnresolvedStar)
+            and other._unparsed_target == self._unparsed_target
+        )
+
 
 class SQLExpression(Expression):
     """Returns Expression which contains a string which is a SQL expression
@@ -370,6 +407,9 @@ class SQLExpression(Expression):
         expr.expression_string.expression = self._expr
         return expr
 
+    def __eq__(self, other: Any) -> bool:
+        return other is not None and isinstance(other, SQLExpression) and 
other._expr == self._expr
+
 
 class SortOrder(Expression):
     def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None:
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 5f1eb9c06d7..c73e6ec1ee4 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -40,6 +40,7 @@ from pyspark.sql.connect.expressions import (
     LiteralExpression,
     ColumnReference,
     UnresolvedFunction,
+    UnresolvedStar,
     SQLExpression,
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
@@ -186,7 +187,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
 
 
 def col(col: str) -> Column:
-    return Column(ColumnReference(col))
+    if col == "*":
+        return Column(UnresolvedStar(unparsed_target=None))
+    elif col.endswith(".*"):
+        return Column(UnresolvedStar(unparsed_target=col))
+    else:
+        return Column(ColumnReference(unparsed_identifier=col))
 
 
 col.__doc__ = pysparkfuncs.col.__doc__
@@ -2389,9 +2395,6 @@ def _test() -> None:
         # TODO(SPARK-41843): Implement SparkSession.udf
         del pyspark.sql.connect.functions.call_udf.__doc__
 
-        # TODO(SPARK-41845): Fix count bug
-        del pyspark.sql.connect.functions.count.__doc__
-
         globs["spark"] = (
             PySparkSession.builder.appName("sql.connect.functions tests")
             .remote("local[4]")
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 3aa070ff8b6..cc728808d3a 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -80,14 +80,8 @@ class GroupedData:
 
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            # There is a special case for count(*) which is rewritten into count(1).
             # Convert the dict into key value pairs
-            aggregate_cols = [
-                _invoke_function(
-                    exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
-                )
-                for k in exprs[0]
-            ]
+            aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 462384999bb..87c16964102 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe8#\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -262,7 +262,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 4689
+    _EXPRESSION._serialized_end = 4731
     _EXPRESSION_WINDOW._serialized_start = 1347
     _EXPRESSION_WINDOW._serialized_end = 2130
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637
@@ -292,17 +292,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866
     _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916
     _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3958
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3960
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4004
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4007
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4139
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4142
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4329
-    _EXPRESSION_ALIAS._serialized_start = 4331
-    _EXPRESSION_ALIAS._serialized_end = 4451
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4454
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4612
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4614
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4676
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4184
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4371
+    _EXPRESSION_ALIAS._serialized_start = 4373
+    _EXPRESSION_ALIAS._serialized_end = 4493
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 5f64159b854..45889c1518f 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -699,22 +699,33 @@ class Expression(google.protobuf.message.Message):
 
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        TARGET_FIELD_NUMBER: builtins.int
-        @property
-        def target(
-            self,
-        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
-            """(Optional) The target of the expansion, either be a table name or struct name, this
-            is a list of identifiers that is the path of the expansion.
-            """
+        UNPARSED_TARGET_FIELD_NUMBER: builtins.int
+        unparsed_target: builtins.str
+        """(Optional) The target of the expansion.
+
+        If set, it should end with '.*' and will be parsed by 'parseAttributeName'
+        in the server side.
+        """
         def __init__(
             self,
             *,
-            target: collections.abc.Iterable[builtins.str] | None = ...,
+            unparsed_target: builtins.str | None = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target"
+            ],
+        ) -> builtins.bool: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["target", b"target"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target"
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_unparsed_target", b"_unparsed_target"]
+        ) -> typing_extensions.Literal["unparsed_target"] | None: ...
 
     class UnresolvedRegex(google.protobuf.message.Message):
         """Represents all of the input attributes to a given relational 
operator, for example in
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 199fd6eb9a9..e1792b03a44 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -71,6 +71,64 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
 
         self.assertEqual(str1, str2)
 
+    def test_count_star(self):
+        # SPARK-42099: test count(*), count(col(*)) and count(expr(*))
+
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")]
+
+        cdf = self.connect.createDataFrame(data, schema=["age", "name"])
+        sdf = self.spark.createDataFrame(data, schema=["age", "name"])
+
+        self.assertEqual(
+            cdf.select(CF.count(CF.expr("*")), CF.count(cdf.age)).collect(),
+            sdf.select(SF.count(SF.expr("*")), SF.count(sdf.age)).collect(),
+        )
+
+        self.assertEqual(
+            cdf.select(CF.count(CF.col("*")), CF.count(cdf.age)).collect(),
+            sdf.select(SF.count(SF.col("*")), SF.count(sdf.age)).collect(),
+        )
+
+        self.assertEqual(
+            cdf.select(CF.count("*"), CF.count(cdf.age)).collect(),
+            sdf.select(SF.count("*"), SF.count(sdf.age)).collect(),
+        )
+
+        self.assertEqual(
+            cdf.groupby("name").agg({"*": "count"}).sort("name").collect(),
+            sdf.groupby("name").agg({"*": "count"}).sort("name").collect(),
+        )
+
+        self.assertEqual(
+            cdf.groupby("name")
+            .agg(CF.count(CF.expr("*")), CF.count(cdf.age))
+            .sort("name")
+            .collect(),
+            sdf.groupby("name")
+            .agg(SF.count(SF.expr("*")), SF.count(sdf.age))
+            .sort("name")
+            .collect(),
+        )
+
+        self.assertEqual(
+            cdf.groupby("name")
+            .agg(CF.count(CF.col("*")), CF.count(cdf.age))
+            .sort("name")
+            .collect(),
+            sdf.groupby("name")
+            .agg(SF.count(SF.col("*")), SF.count(sdf.age))
+            .sort("name")
+            .collect(),
+        )
+
+        self.assertEqual(
+            cdf.groupby("name").agg(CF.count("*"), CF.count(cdf.age)).sort("name").collect(),
+            sdf.groupby("name").agg(SF.count("*"), SF.count(sdf.age)).sort("name").collect(),
+        )
+
     def test_broadcast(self):
         from pyspark.sql import functions as SF
         from pyspark.sql.connect import functions as CF

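For reference, a sketch of what the Python client now puts on the wire for a star expression (illustrative, not part of the commit; it uses the generated bindings shown above, and the server parses a set `unparsed_target` with `parseAttributeName`):

    import pyspark.sql.connect.proto as proto

    bare = proto.Expression()
    bare.unresolved_star.SetInParent()  # bare '*': no unparsed_target set

    qualified = proto.Expression()
    qualified.unresolved_star.unparsed_target = "a.b.*"  # must end with '.*'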
