This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ef9c8e0d045 [SPARK-41439][CONNECT][PYTHON][FOLLOWUP] Make unpivot of `connect/dataframe.py` consistent with `pyspark/dataframe.py`
ef9c8e0d045 is described below

commit ef9c8e0d045576fb325ef337319fe6d59b7ce858
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Dec 12 08:34:21 2022 +0800

    [SPARK-41439][CONNECT][PYTHON][FOLLOWUP] Make unpivot of `connect/dataframe.py` consistent with `pyspark/dataframe.py`

    ### What changes were proposed in this pull request?
    This PR makes `unpivot` in `connect/dataframe.py` consistent with `pyspark/dataframe.py` and adds test cases for Connect's `unpivot`.

    This PR follows up https://github.com/apache/spark/pull/38973.

    ### Why are the changes needed?
    1. Makes `unpivot` in `connect/dataframe.py` consistent with `pyspark/dataframe.py`.
    2. Adds test cases for Connect's `unpivot`.

    ### Does this PR introduce _any_ user-facing change?
    No; this covers a new API.

    ### How was this patch tested?
    New test cases.

    Closes #39019 from beliefer/SPARK-41439_followup.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
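For context, the sketch below (not part of the commit; the schema and names are
made up for illustration) shows the argument shapes `unpivot` accepts after this
change, identically on the classic and the Connect `DataFrame`:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, 10.0, 20.0), (2, 30.0, 40.0)], ["id", "q1", "q2"])

    # Lists for both ids and values, as before.
    df.unpivot(["id"], ["q1", "q2"], "quarter", "sales").show()
    # A single column name and a tuple are now also accepted.
    df.unpivot("id", ("q1", "q2"), "quarter", "sales").show()
    # values=None unpivots every column that is not an id column.
    df.unpivot("id", None, "quarter", "sales").show()

    spark.stop()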
"variable", "value") + .toPandas(), + ) + def test_with_columns(self): # SPARK-41256: test withColumn(s). self.assert_eq( diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 83e21e42bad..e0cd54195f3 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -189,7 +189,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): plan = ( df.filter(df.col_name > 3) - .unpivot(["id"], [], "variable", "value") + .unpivot(["id"], None, "variable", "value") ._plan.to_proto(self.connect) ) self.assertTrue(len(plan.root.unpivot.ids) == 1) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org