This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ed2ea7fb50d [SPARK-40915][CONNECT] Improve `on` in Join in Python client ed2ea7fb50d is described below commit ed2ea7fb50d163c1b4b3e18c13ed644bf1358998 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Tue Nov 1 11:27:43 2022 +0900 [SPARK-40915][CONNECT] Improve `on` in Join in Python client ### What changes were proposed in this pull request? 1. Fix Join's `on` from ANY to concrete types (e.g. str, list of str, column, list of columns) 2. When `on` is a str or a list of str, it should generate a proto plan with `using_columns` 3. When `on` is a column or a list of columns, it should generate a proto plan with `join_condition`. ### Why are the changes needed? Improve API coverage ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests (UT) Closes #38393 from amaliujia/python_join_on_improvement. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/dataframe.py | 9 +++++++-- python/pyspark/sql/connect/plan.py | 11 +++++++++-- .../pyspark/sql/tests/connect/test_connect_plan_only.py | 17 +++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 03a766aff30..1ec105f5afd 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -217,8 +217,13 @@ class DataFrame(object): def head(self, n: int) -> Optional["pandas.DataFrame"]: return self.limit(n).toPandas() - # TODO(martin.grund) fix mypu - def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> "DataFrame": + # TODO: extend `on` to also be type List[ColumnRef]. 
+ def join( + self, + other: "DataFrame", + on: Optional[Union[str, List[str], ColumnRef]] = None, + how: Optional[str] = None, + ) -> "DataFrame": if self._plan is None: raise Exception("Cannot join when self._plan is empty.") if other._plan is None: diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 9f183c116eb..07fecfb47f3 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -537,7 +537,7 @@ class Join(LogicalPlan): self, left: Optional["LogicalPlan"], right: "LogicalPlan", - on: "ColumnOrString", + on: Optional[Union[str, List[str], ColumnRef]], how: Optional[str], ) -> None: super().__init__(left) @@ -575,7 +575,14 @@ class Join(LogicalPlan): rel = proto.Relation() rel.join.left.CopyFrom(self.left.plan(session)) rel.join.right.CopyFrom(self.right.plan(session)) - rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session)) + if self.on is not None: + if not isinstance(self.on, list): + if isinstance(self.on, str): + rel.join.using_columns.append(self.on) + else: + rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session)) + else: + rel.join.using_columns.extend(self.on) rel.join.join_type = self.how return rel diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 14b939e019b..622340a3ef1 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -37,6 +37,23 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): self.assertIsNotNone(plan.root, "Root relation must be set") self.assertIsNotNone(plan.root.read) + def test_join_using_columns(self): + left_input = self.connect.readTable(table_name=self.tbl_name) + right_input = self.connect.readTable(table_name=self.tbl_name) + plan = left_input.join(other=right_input, on="join_column")._plan.to_proto(self.connect) + 
self.assertEqual(len(plan.root.join.using_columns), 1) + + plan2 = left_input.join(other=right_input, on=["col1", "col2"])._plan.to_proto(self.connect) + self.assertEqual(len(plan2.root.join.using_columns), 2) + + def test_join_condition(self): + left_input = self.connect.readTable(table_name=self.tbl_name) + right_input = self.connect.readTable(table_name=self.tbl_name) + plan = left_input.join( + other=right_input, on=left_input.name == right_input.name + )._plan.to_proto(self.connect) + self.assertIsNotNone(plan.root.join.join_condition) + def test_filter(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org