This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5facaece4df [SPARK-44106][PYTHON][CONNECT] Add `__repr__` for `GroupedData` 5facaece4df is described below commit 5facaece4dfa1fa45e8c8f7bd7d92f11e2c91fd8 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Jun 21 14:48:51 2023 +0800 [SPARK-44106][PYTHON][CONNECT] Add `__repr__` for `GroupedData` ### What changes were proposed in this pull request? Add `__repr__` for `GroupedData` ### Why are the changes needed? `GroupedData.__repr__` is missing ### Does this PR introduce _any_ user-facing change? yes 1. On Scala side: ``` scala> val df = Seq(("414243", "4243")).toDF("e", "f") df: org.apache.spark.sql.DataFrame = [e: string, f: string] scala> df.groupBy("e") res0: org.apache.spark.sql.RelationalGroupedDataset = RelationalGroupedDataset: [grouping expressions: [e: string], value: [e: string, f: string], type: GroupBy] scala> df.groupBy(df.col("e")) res1: org.apache.spark.sql.RelationalGroupedDataset = RelationalGroupedDataset: [grouping expressions: [e: string], value: [e: string, f: string], type: GroupBy] ``` 2. On vanilla PySpark: before this PR: ``` In [1]: df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) In [2]: df Out[2]: DataFrame[e: string, f: string] In [3]: df.groupBy("e") Out[3]: <pyspark.sql.group.GroupedData at 0x10423a4c0> In [4]: df.groupBy(df.e) Out[4]: <pyspark.sql.group.GroupedData at 0x1041dd640> ``` after this PR: ``` In [1]: df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) In [2]: df Out[2]: DataFrame[e: string, f: string] In [3]: df.groupBy("e") Out[3]: GroupedData[grouping expressions: [e], value: [e: string, f: string], type: GroupBy] In [4]: df.groupBy(df.e) Out[4]: GroupedData[grouping expressions: [e: string], value: [e: string, f: string], type: GroupBy] ``` 3. 
On the Spark Connect Python Client: before this PR: ``` In [1]: df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) In [2]: df Out[2]: DataFrame[e: string, f: string] In [3]: df.groupBy("e") Out[3]: <pyspark.sql.connect.group.GroupedData at 0x1046157c0> In [4]: df.groupBy(df.e) Out[4]: <pyspark.sql.connect.group.GroupedData at 0x11da5ceb0> ``` after this PR: ``` In [1]: df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) In [2]: df Out[2]: DataFrame[e: string, f: string] In [3]: df.groupBy("e") Out[3]: GroupedData[grouping expressions: [e], value: [e: string, f: string], type: GroupBy] In [4]: df.groupBy(df.e) Out[4]: GroupedData[grouping expressions: [e], value: [e: string, f: string], type: GroupBy] // different from vanilla PySpark ``` Note that because expressions in the Python Client are not resolved, the string representation can differ from vanilla PySpark. ### How was this patch tested? Added doctests. Closes #41674 from zhengruifeng/group_repr. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/group.py | 19 +++++++++++++++++++ python/pyspark/sql/group.py | 11 +++++++++++ 2 files changed, 30 insertions(+) diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index e75c8029ef2..a393d2cb37e 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -83,6 +83,25 @@ class GroupedData: self._pivot_col = pivot_col self._pivot_values = pivot_values + def __repr__(self) -> str: + # the expressions are not resolved here, + # so the string representation can be different from vanilla PySpark.
+ grouping_str = ", ".join(str(e._expr) for e in self._grouping_cols) + grouping_str = f"grouping expressions: [{grouping_str}]" + + value_str = ", ".join("%s: %s" % c for c in self._df.dtypes) + + if self._group_type == "groupby": + type_str = "GroupBy" + elif self._group_type == "rollup": + type_str = "RollUp" + elif self._group_type == "cube": + type_str = "Cube" + else: + type_str = "Pivot" + + return f"GroupedData[{grouping_str}, value: [{value_str}], type: {type_str}]" + @overload def agg(self, *exprs: Column) -> "DataFrame": ... diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e33e3d6ec5e..9568a971229 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -70,6 +70,14 @@ class GroupedData(PandasGroupedOpsMixin): self._df = df self.session: SparkSession = df.sparkSession + def __repr__(self) -> str: + index = 26 # index to truncate string from the JVM side + jvm_string = self._jgd.toString() + if jvm_string is not None and len(jvm_string) > index and jvm_string[index] == "[": + return f"GroupedData{jvm_string[index:]}" + else: + return super().__repr__() + @overload def agg(self, *exprs: Column) -> DataFrame: ... @@ -133,6 +141,9 @@ class GroupedData(PandasGroupedOpsMixin): Group-by name, and count each group. + >>> df.groupBy(df.name) + GroupedData[grouping...: [name...], value: [age: bigint, name: string], type: GroupBy] + >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show() +-----+--------+ | name|count(1)| --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org