This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ed2ea7fb50d [SPARK-40915][CONNECT] Improve `on` in Join in Python client
ed2ea7fb50d is described below

commit ed2ea7fb50d163c1b4b3e18c13ed644bf1358998
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Nov 1 11:27:43 2022 +0900

    [SPARK-40915][CONNECT] Improve `on` in Join in Python client
    
    ### What changes were proposed in this pull request?
    
    1. Change Join's `on` parameter from `Any` to concrete types (str, list of str, or a column; support for a list of columns is still a TODO)
    2. When `on` is a str or a list of str, generate a proto plan with `using_columns`.
    3. When `on` is a column (and, once supported, a list of columns), generate a proto plan with `join_condition`.
    
    ### Why are the changes needed?
    
    Improve API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests (`test_join_using_columns`, `test_join_condition`).
    
    Closes #38393 from amaliujia/python_join_on_improvement.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py                 |  9 +++++++--
 python/pyspark/sql/connect/plan.py                      | 11 +++++++++--
 .../pyspark/sql/tests/connect/test_connect_plan_only.py | 17 +++++++++++++++++
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 03a766aff30..1ec105f5afd 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -217,8 +217,13 @@ class DataFrame(object):
     def head(self, n: int) -> Optional["pandas.DataFrame"]:
         return self.limit(n).toPandas()
 
-    # TODO(martin.grund) fix mypu
-    def join(self, other: "DataFrame", on: Any, how: Optional[str] = None) -> 
"DataFrame":
+    # TODO: extend `on` to also be type List[ColumnRef].
+    def join(
+        self,
+        other: "DataFrame",
+        on: Optional[Union[str, List[str], ColumnRef]] = None,
+        how: Optional[str] = None,
+    ) -> "DataFrame":
         if self._plan is None:
             raise Exception("Cannot join when self._plan is empty.")
         if other._plan is None:
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 9f183c116eb..07fecfb47f3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -537,7 +537,7 @@ class Join(LogicalPlan):
         self,
         left: Optional["LogicalPlan"],
         right: "LogicalPlan",
-        on: "ColumnOrString",
+        on: Optional[Union[str, List[str], ColumnRef]],
         how: Optional[str],
     ) -> None:
         super().__init__(left)
@@ -575,7 +575,14 @@ class Join(LogicalPlan):
         rel = proto.Relation()
         rel.join.left.CopyFrom(self.left.plan(session))
         rel.join.right.CopyFrom(self.right.plan(session))
-        rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, 
session))
+        if self.on is not None:
+            if not isinstance(self.on, list):
+                if isinstance(self.on, str):
+                    rel.join.using_columns.append(self.on)
+                else:
+                    
rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+            else:
+                rel.join.using_columns.extend(self.on)
         rel.join.join_type = self.how
         return rel
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py 
b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 14b939e019b..622340a3ef1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -37,6 +37,23 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertIsNotNone(plan.root, "Root relation must be set")
         self.assertIsNotNone(plan.root.read)
 
+    def test_join_using_columns(self):
+        left_input = self.connect.readTable(table_name=self.tbl_name)
+        right_input = self.connect.readTable(table_name=self.tbl_name)
+        plan = left_input.join(other=right_input, 
on="join_column")._plan.to_proto(self.connect)
+        self.assertEqual(len(plan.root.join.using_columns), 1)
+
+        plan2 = left_input.join(other=right_input, on=["col1", 
"col2"])._plan.to_proto(self.connect)
+        self.assertEqual(len(plan2.root.join.using_columns), 2)
+
+    def test_join_condition(self):
+        left_input = self.connect.readTable(table_name=self.tbl_name)
+        right_input = self.connect.readTable(table_name=self.tbl_name)
+        plan = left_input.join(
+            other=right_input, on=left_input.name == right_input.name
+        )._plan.to_proto(self.connect)
+        self.assertIsNotNone(plan.root.join.join_condition)
+
     def test_filter(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3)._plan.to_proto(self.connect)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to