This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a27ccd78875 [SPARK-40879][CONNECT] Support Join UsingColumns in proto
a27ccd78875 is described below

commit a27ccd788750c1b1394c8274f79643cb2ad6cf49
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Oct 25 23:49:02 2022 +0800

    [SPARK-40879][CONNECT] Support Join UsingColumns in proto
    
    ### What changes were proposed in this pull request?
    
    I was working on refactoring Connect proto tests from the Catalyst DSL to
    the DataFrame API, and identified that Join in Connect does not support
    `UsingColumns`. This is a gap between the Connect proto and the DataFrame
    API. It also blocks the refactoring work: without `UsingColumns`, there is
    no compatible DataFrame Join API that we can convert the existing tests to.
    
    This PR adds support for Join's `UsingColumns`.
    
    ### Why are the changes needed?
    
    1. Improve API coverage.
    2. Unblock the test refactoring.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #38345 from amaliujia/proto-join-using-columns.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  6 ++
 .../org/apache/spark/sql/connect/dsl/package.scala | 32 ++++++++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 17 ++++--
 .../connect/planner/SparkConnectPlannerSuite.scala | 14 +++++
 .../connect/planner/SparkConnectProtoSuite.scala   | 16 +++++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 64 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 14 +++++
 7 files changed, 123 insertions(+), 40 deletions(-)
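
For context, the gap described above is the DataFrame API's join-on-shared-columns
variant. A minimal sketch of that call (illustrative only, not part of this
commit; df1 and df2 are hypothetical DataFrames that both carry a "name" column):

    // Inner join on the shared "name" column; unlike an explicit equality
    // condition, the "name" column appears only once in the output.
    val joined = df1.join(df2, Seq("name"))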

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 7dbde775ee8..94010487ee5 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -109,6 +109,12 @@ message Join {
   Relation right = 2;
   Expression join_condition = 3;
   JoinType join_type = 4;
+  // Optional. using_columns provides a list of columns that should be present on both sides of
+  // the join inputs that this Join will join on. For example A JOIN B USING col_name is
+  // equivalent to A JOIN B ON A.col_name = B.col_name.
+  //
+  // This field does not co-exist with join_condition.
+  repeated string using_columns = 5;
 
   enum JoinType {
     JOIN_TYPE_UNSPECIFIED = 0;
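
To make the new field concrete, here is a sketch of building a Join relation
that joins on a shared column via the generated builders (left and right are
assumed to be already-constructed proto.Relation values):

    // using_columns and join_condition are mutually exclusive; the planner
    // rejects plans that set both.
    val joined = proto.Relation.newBuilder
      .setJoin(
        proto.Join.newBuilder
          .setLeft(left)
          .setRight(right)
          .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
          .addUsingColumns("name"))
      .build()
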
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 4630c86049c..6ae6dfa1577 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -236,15 +236,45 @@ package object dsl {
           .build()
 
       def join(
+          otherPlan: proto.Relation,
+          joinType: JoinType,
+          condition: Option[proto.Expression]): proto.Relation = {
+        join(otherPlan, joinType, Seq(), condition)
+      }
+
+      def join(otherPlan: proto.Relation, condition: Option[proto.Expression]): proto.Relation = {
+        join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition)
+      }
+
+      def join(otherPlan: proto.Relation): proto.Relation = {
+        join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None)
+      }
+
+      def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation = {
+        join(otherPlan, joinType, Seq(), None)
+      }
+
+      def join(
+          otherPlan: proto.Relation,
+          joinType: JoinType,
+          usingColumns: Seq[String]): proto.Relation = {
+        join(otherPlan, joinType, usingColumns, None)
+      }
+
+      private def join(
           otherPlan: proto.Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
-          condition: Option[proto.Expression] = None): proto.Relation = {
+          usingColumns: Seq[String],
+          condition: Option[proto.Expression]): proto.Relation = {
         val relation = proto.Relation.newBuilder()
         val join = proto.Join.newBuilder()
         join
           .setLeft(logicalPlan)
           .setRight(otherPlan)
           .setJoinType(joinType)
+        if (usingColumns.nonEmpty) {
+          join.addAllUsingColumns(usingColumns.asJava)
+        }
         if (condition.isDefined) {
           join.setJoinCondition(condition.get)
         }
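
With these overloads in place, a USING join can be expressed directly through
the DSL. A sketch of the call shape (left and right are hypothetical
proto.Relation values; this mirrors the usage added to SparkConnectProtoSuite
below):

    import org.apache.spark.sql.connect.dsl.plans._

    // Inner join on the shared "name" column via using_columns.
    val plan = left.join(right, JoinType.JOIN_TYPE_INNER, Seq("name"))
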
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 880618cc333..9e3899f4a1a 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sample, SubqueryAlias}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.QueryExecution
@@ -292,14 +292,23 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
   private def transformJoin(rel: proto.Join): LogicalPlan = {
     assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+    if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
+      throw InvalidPlanInput(
+        s"Using columns or join conditions cannot be set at the same time in 
Join")
+    }
     val joinCondition =
       if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None
-
+    val catalystJointype = transformJoinType(
+      if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER)
+    val joinType = if (rel.getUsingColumnsCount > 0) {
+      UsingJoin(catalystJointype, rel.getUsingColumnsList.asScala.toSeq)
+    } else {
+      catalystJointype
+    }
     logical.Join(
       left = transformRelation(rel.getLeft),
       right = transformRelation(rel.getRight),
-      joinType = transformJoinType(
-        if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER),
+      joinType = joinType,
       condition = joinCondition,
       hint = logical.JoinHint.NONE)
   }
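
In Catalyst terms, a non-empty using_columns list wraps the base join type in
UsingJoin, which the analyzer later resolves into an equality predicate per
listed column. A sketch of the resulting logical plan (leftPlan and rightPlan
stand in for the transformed child relations):

    // Equivalent of: left JOIN right USING (name)
    logical.Join(
      left = leftPlan,
      right = rightPlan,
      joinType = UsingJoin(Inner, Seq("name")),
      condition = None,
      hint = logical.JoinHint.NONE)
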
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 980e899c26e..6fc47e07c59 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -220,6 +220,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(res.nodeName == "Join")
     assert(res != null)
 
+    val e = intercept[InvalidPlanInput] {
+      val simpleJoin = proto.Relation.newBuilder
+        .setJoin(
+          proto.Join.newBuilder
+            .setLeft(readRel)
+            .setRight(readRel)
+            .addUsingColumns("test_col")
+            .setJoinCondition(joinCondition))
+        .build()
+      transform(simpleJoin)
+    }
+    assert(
+      e.getMessage.contains(
+        "Using columns or join conditions cannot be set at the same time in Join"))
   }
 
   test("Simple Projection") {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index d8bb1684cb8..0325b6573bd 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -20,7 +20,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 
 /**
@@ -32,11 +32,13 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
   lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))
 
-  lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
+  lazy val connectTestRelation2 = createLocalRelationProto(
+    Seq($"key".int, $"value".int, $"name".string))
 
   lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)
 
-  lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
+  lazy val sparkTestRelation2: LocalRelation =
+    LocalRelation($"key".int, $"value".int, $"name".string)
 
   test("Basic select") {
     val connectPlan = {
@@ -117,6 +119,14 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y)
       comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
     }
+
+    val connectPlan4 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(
+        connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, Seq("name")))
+    }
+    val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, UsingJoin(Inner, Seq("name")))
+    comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false)
   }
 
   test("Test sample") {
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index d9a596fba8c..2a38a014926 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -64,35 +64,35 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _FILTER._serialized_start = 1512
     _FILTER._serialized_end = 1624
     _JOIN._serialized_start = 1627
-    _JOIN._serialized_end = 2040
-    _JOIN_JOINTYPE._serialized_start = 1853
-    _JOIN_JOINTYPE._serialized_end = 2040
-    _UNION._serialized_start = 2043
-    _UNION._serialized_end = 2248
-    _UNION_UNIONTYPE._serialized_start = 2164
-    _UNION_UNIONTYPE._serialized_end = 2248
-    _LIMIT._serialized_start = 2250
-    _LIMIT._serialized_end = 2326
-    _OFFSET._serialized_start = 2328
-    _OFFSET._serialized_end = 2407
-    _AGGREGATE._serialized_start = 2410
-    _AGGREGATE._serialized_end = 2735
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2639
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2735
-    _SORT._serialized_start = 2738
-    _SORT._serialized_end = 3240
-    _SORT_SORTFIELD._serialized_start = 2858
-    _SORT_SORTFIELD._serialized_end = 3046
-    _SORT_SORTDIRECTION._serialized_start = 3048
-    _SORT_SORTDIRECTION._serialized_end = 3156
-    _SORT_SORTNULLS._serialized_start = 3158
-    _SORT_SORTNULLS._serialized_end = 3240
-    _DEDUPLICATE._serialized_start = 3243
-    _DEDUPLICATE._serialized_end = 3385
-    _LOCALRELATION._serialized_start = 3387
-    _LOCALRELATION._serialized_end = 3480
-    _SAMPLE._serialized_start = 3483
-    _SAMPLE._serialized_end = 3723
-    _SAMPLE_SEED._serialized_start = 3697
-    _SAMPLE_SEED._serialized_end = 3723
+    _JOIN._serialized_end = 2077
+    _JOIN_JOINTYPE._serialized_start = 1890
+    _JOIN_JOINTYPE._serialized_end = 2077
+    _UNION._serialized_start = 2080
+    _UNION._serialized_end = 2285
+    _UNION_UNIONTYPE._serialized_start = 2201
+    _UNION_UNIONTYPE._serialized_end = 2285
+    _LIMIT._serialized_start = 2287
+    _LIMIT._serialized_end = 2363
+    _OFFSET._serialized_start = 2365
+    _OFFSET._serialized_end = 2444
+    _AGGREGATE._serialized_start = 2447
+    _AGGREGATE._serialized_end = 2772
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2676
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2772
+    _SORT._serialized_start = 2775
+    _SORT._serialized_end = 3277
+    _SORT_SORTFIELD._serialized_start = 2895
+    _SORT_SORTFIELD._serialized_end = 3083
+    _SORT_SORTDIRECTION._serialized_start = 3085
+    _SORT_SORTDIRECTION._serialized_end = 3193
+    _SORT_SORTNULLS._serialized_start = 3195
+    _SORT_SORTNULLS._serialized_end = 3277
+    _DEDUPLICATE._serialized_start = 3280
+    _DEDUPLICATE._serialized_end = 3422
+    _LOCALRELATION._serialized_start = 3424
+    _LOCALRELATION._serialized_end = 3517
+    _SAMPLE._serialized_start = 3520
+    _SAMPLE._serialized_end = 3760
+    _SAMPLE_SEED._serialized_start = 3734
+    _SAMPLE_SEED._serialized_end = 3760
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index df179df1480..d3186c4e3df 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -467,6 +467,7 @@ class Join(google.protobuf.message.Message):
     RIGHT_FIELD_NUMBER: builtins.int
     JOIN_CONDITION_FIELD_NUMBER: builtins.int
     JOIN_TYPE_FIELD_NUMBER: builtins.int
+    USING_COLUMNS_FIELD_NUMBER: builtins.int
     @property
     def left(self) -> global___Relation: ...
     @property
@@ -474,6 +475,16 @@ class Join(google.protobuf.message.Message):
     @property
     def join_condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: ...
     join_type: global___Join.JoinType.ValueType
+    @property
+    def using_columns(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """Optional. using_columns provides a list of columns that should be present on both sides of
+        the join inputs that this Join will join on. For example A JOIN B USING col_name is
+        equivalent to A JOIN B ON A.col_name = B.col_name.
+
+        This field does not co-exist with join_condition.
+        """
     def __init__(
         self,
         *,
@@ -481,6 +492,7 @@ class Join(google.protobuf.message.Message):
         right: global___Relation | None = ...,
         join_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
         join_type: global___Join.JoinType.ValueType = ...,
+        using_columns: collections.abc.Iterable[builtins.str] | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -499,6 +511,8 @@ class Join(google.protobuf.message.Message):
             b"left",
             "right",
             b"right",
+            "using_columns",
+            b"using_columns",
         ],
     ) -> None: ...
 

