This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new eb4bb44d446 [SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))` eb4bb44d446 is described below commit eb4bb44d446a0416c360da8127659b10f98e5ceb Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Sat Jan 21 09:48:25 2023 +0800 [SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))` ### What changes were proposed in this pull request? 1. Add `UnresolvedStar` to `expressions.py`; 2. Fix `count(*)` and `count(col(*))`, which should return `Column(UnresolvedStar(None))` instead of `Column(UnresolvedAttribute("*"))`, see: https://github.com/apache/spark/blob/68531ada34db72d352c39396f85458a8370af812/sql/core/src/main/scala/org/apache/spark/sql/Column.scala#L144-L150 3. Remove the `count(*) -> count(1)` transformation in `group.py`, since it's no longer needed. ### Why are the changes needed? https://github.com/apache/spark/pull/39636 fixed the `count(*)` issue on the server side, and `count(expr(*))` works after that PR. This PR makes the corresponding changes on the Python client side, in order to support `count(*)` and `count(col(*))`. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Enabled an existing UT and added a new UT. Closes #39622 from zhengruifeng/connect_count_star. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../main/protobuf/spark/connect/expressions.proto | 9 ++-- .../sql/connect/planner/SparkConnectPlanner.scala | 16 ++++-- .../connect/planner/SparkConnectPlannerSuite.scala | 2 +- python/pyspark/sql/connect/dataframe.py | 4 +- python/pyspark/sql/connect/expressions.py | 46 +++++++++++++++-- python/pyspark/sql/connect/functions.py | 11 ++-- python/pyspark/sql/connect/group.py | 8 +-- .../pyspark/sql/connect/proto/expressions_pb2.py | 30 +++++------ .../pyspark/sql/connect/proto/expressions_pb2.pyi | 31 ++++++++---- .../sql/tests/connect/test_connect_function.py | 58 ++++++++++++++++++++++ 10 files changed, 166 insertions(+), 49 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 349e2455be3..f7feae0e2f0 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -225,9 +225,12 @@ message Expression { // UnresolvedStar is used to expand all the fields of a relation or struct. message UnresolvedStar { - // (Optional) The target of the expansion, either be a table name or struct name, this - // is a list of identifiers that is the path of the expansion. - repeated string target = 1; + + // (Optional) The target of the expansion. + // + // If set, it should end with '.*' and will be parsed by 'parseAttributeName' + // in the server side. 
+ optional string unparsed_target = 1; } // Represents all of the input attributes to a given relational operator, for example in diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 3d63558eb3e..d72aa162132 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1002,11 +1002,19 @@ class SparkConnectPlanner(session: SparkSession) { session.sessionState.sqlParser.parseExpression(expr.getExpression) } - private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = { - if (regex.getTargetList.isEmpty) { - UnresolvedStar(Option.empty) + private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = { + if (star.hasUnparsedTarget) { + val target = star.getUnparsedTarget + if (!target.endsWith(".*")) { + throw InvalidPlanInput( + s"UnresolvedStar requires a unparsed target ending with '.*', " + + s"but got $target.") + } + + UnresolvedStar( + Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2)))) } else { - UnresolvedStar(Some(regex.getTargetList.asScala.toSeq)) + UnresolvedStar(None) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 63e5415b44f..d8baa182e5a 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -552,7 +552,7 @@ class SparkConnectPlannerSuite extends 
SparkFunSuite with SparkConnectPlanTest { .addExpressions( proto.Expression .newBuilder() - .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build()) + .setUnresolvedStar(UnresolvedStar.newBuilder().setUnparsedTarget("a.b.*").build()) .build()) .build() diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 11c0ef6fc06..d82862a870b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -249,7 +249,7 @@ class DataFrame: else: return Column(SortOrder(col._expr)) else: - return Column(SortOrder(ColumnReference(name=col))) + return Column(SortOrder(ColumnReference(col))) if isinstance(numPartitions, int): if not numPartitions > 0: @@ -1176,7 +1176,7 @@ class DataFrame: from pyspark.sql.connect.expressions import ColumnReference if isinstance(col, str): - col = Column(ColumnReference(name=col)) + col = Column(ColumnReference(col)) elif not isinstance(col, Column): raise TypeError("col must be a string or a column, but got %r" % type(col)) if not isinstance(fractions, dict): diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 6469c1917ec..c8d361af2a5 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -336,10 +336,10 @@ class ColumnReference(Expression): treat it as an unresolved attribute. 
Attributes that have the same fully qualified name are identical""" - def __init__(self, name: str) -> None: + def __init__(self, unparsed_identifier: str) -> None: super().__init__() - assert isinstance(name, str) - self._unparsed_identifier = name + assert isinstance(unparsed_identifier, str) + self._unparsed_identifier = unparsed_identifier def name(self) -> str: """Returns the qualified name of the column reference.""" @@ -354,6 +354,43 @@ class ColumnReference(Expression): def __repr__(self) -> str: return f"{self._unparsed_identifier}" + def __eq__(self, other: Any) -> bool: + return ( + other is not None + and isinstance(other, ColumnReference) + and other._unparsed_identifier == self._unparsed_identifier + ) + + +class UnresolvedStar(Expression): + def __init__(self, unparsed_target: Optional[str]): + super().__init__() + + if unparsed_target is not None: + assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*") + + self._unparsed_target = unparsed_target + + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": + expr = proto.Expression() + expr.unresolved_star.SetInParent() + if self._unparsed_target is not None: + expr.unresolved_star.unparsed_target = self._unparsed_target + return expr + + def __repr__(self) -> str: + if self._unparsed_target is not None: + return f"unresolvedstar({self._unparsed_target})" + else: + return "unresolvedstar()" + + def __eq__(self, other: Any) -> bool: + return ( + other is not None + and isinstance(other, UnresolvedStar) + and other._unparsed_target == self._unparsed_target + ) + class SQLExpression(Expression): """Returns Expression which contains a string which is a SQL expression @@ -370,6 +407,9 @@ class SQLExpression(Expression): expr.expression_string.expression = self._expr return expr + def __eq__(self, other: Any) -> bool: + return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr + class SortOrder(Expression): def __init__(self, child: 
Expression, ascending: bool = True, nullsFirst: bool = True) -> None: diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 5f1eb9c06d7..c73e6ec1ee4 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -40,6 +40,7 @@ from pyspark.sql.connect.expressions import ( LiteralExpression, ColumnReference, UnresolvedFunction, + UnresolvedStar, SQLExpression, LambdaFunction, UnresolvedNamedLambdaVariable, @@ -186,7 +187,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column: def col(col: str) -> Column: - return Column(ColumnReference(col)) + if col == "*": + return Column(UnresolvedStar(unparsed_target=None)) + elif col.endswith(".*"): + return Column(UnresolvedStar(unparsed_target=col)) + else: + return Column(ColumnReference(unparsed_identifier=col)) col.__doc__ = pysparkfuncs.col.__doc__ @@ -2389,9 +2395,6 @@ def _test() -> None: # TODO(SPARK-41843): Implement SparkSession.udf del pyspark.sql.connect.functions.call_udf.__doc__ - # TODO(SPARK-41845): Fix count bug - del pyspark.sql.connect.functions.count.__doc__ - globs["spark"] = ( PySparkSession.builder.appName("sql.connect.functions tests") .remote("local[4]") diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 3aa070ff8b6..cc728808d3a 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -80,14 +80,8 @@ class GroupedData: assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - # There is a special case for count(*) which is rewritten into count(1). 
# Convert the dict into key value pairs - aggregate_cols = [ - _invoke_function( - exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k) - ) - for k in exprs[0] - ] + aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]] else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 462384999bb..87c16964102 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe8#\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] 
) @@ -262,7 +262,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 4689 + _EXPRESSION._serialized_end = 4731 _EXPRESSION_WINDOW._serialized_start = 1347 _EXPRESSION_WINDOW._serialized_end = 2130 _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637 @@ -292,17 +292,17 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866 _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916 _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3958 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3960 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4004 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4007 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4139 - _EXPRESSION_UPDATEFIELDS._serialized_start = 4142 - _EXPRESSION_UPDATEFIELDS._serialized_end = 4329 - _EXPRESSION_ALIAS._serialized_start = 4331 - _EXPRESSION_ALIAS._serialized_end = 4451 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4454 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4612 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4614 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4676 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181 + _EXPRESSION_UPDATEFIELDS._serialized_start = 4184 + _EXPRESSION_UPDATEFIELDS._serialized_end = 4371 + _EXPRESSION_ALIAS._serialized_start = 4373 + _EXPRESSION_ALIAS._serialized_end = 4493 + _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654 + 
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 5f64159b854..45889c1518f 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -699,22 +699,33 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - TARGET_FIELD_NUMBER: builtins.int - @property - def target( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """(Optional) The target of the expansion, either be a table name or struct name, this - is a list of identifiers that is the path of the expansion. - """ + UNPARSED_TARGET_FIELD_NUMBER: builtins.int + unparsed_target: builtins.str + """(Optional) The target of the expansion. + + If set, it should end with '.*' and will be parsed by 'parseAttributeName' + in the server side. + """ def __init__( self, *, - target: collections.abc.Iterable[builtins.str] | None = ..., + unparsed_target: builtins.str | None = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target" + ], + ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["target", b"target"] + self, + field_name: typing_extensions.Literal[ + "_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target" + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_unparsed_target", b"_unparsed_target"] + ) -> typing_extensions.Literal["unparsed_target"] | None: ... 
class UnresolvedRegex(google.protobuf.message.Message): """Represents all of the input attributes to a given relational operator, for example in diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 199fd6eb9a9..e1792b03a44 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -71,6 +71,64 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): self.assertEqual(str1, str2) + def test_count_star(self): + # SPARK-42099: test count(*), count(col(*)) and count(expr(*)) + + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")] + + cdf = self.connect.createDataFrame(data, schema=["age", "name"]) + sdf = self.spark.createDataFrame(data, schema=["age", "name"]) + + self.assertEqual( + cdf.select(CF.count(CF.expr("*")), CF.count(cdf.age)).collect(), + sdf.select(SF.count(SF.expr("*")), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.select(CF.count(CF.col("*")), CF.count(cdf.age)).collect(), + sdf.select(SF.count(SF.col("*")), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.select(CF.count("*"), CF.count(cdf.age)).collect(), + sdf.select(SF.count("*"), SF.count(sdf.age)).collect(), + ) + + self.assertEqual( + cdf.groupby("name").agg({"*": "count"}).sort("name").collect(), + sdf.groupby("name").agg({"*": "count"}).sort("name").collect(), + ) + + self.assertEqual( + cdf.groupby("name") + .agg(CF.count(CF.expr("*")), CF.count(cdf.age)) + .sort("name") + .collect(), + sdf.groupby("name") + .agg(SF.count(SF.expr("*")), SF.count(sdf.age)) + .sort("name") + .collect(), + ) + + self.assertEqual( + cdf.groupby("name") + .agg(CF.count(CF.col("*")), CF.count(cdf.age)) + .sort("name") + .collect(), + sdf.groupby("name") + .agg(SF.count(SF.col("*")), SF.count(sdf.age)) + .sort("name") + 
.collect(), + ) + + self.assertEqual( + cdf.groupby("name").agg(CF.count("*"), CF.count(cdf.age)).sort("name").collect(), + sdf.groupby("name").agg(SF.count("*"), SF.count(sdf.age)).sort("name").collect(), + ) + def test_broadcast(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org