This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9695b2cb59b [SPARK-41026][CONNECT] Support Repartition in Connect Proto 9695b2cb59b is described below commit 9695b2cb59b497709ca0050d754491d935742530 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Tue Nov 8 09:00:12 2022 +0800 [SPARK-41026][CONNECT] Support Repartition in Connect Proto ### What changes were proposed in this pull request? Support `Repartition` in Connect proto, which further supports two APIs: `repartition` (shuffle=true) and `coalesce` (shuffle=false). ### Why are the changes needed? Improve API coverage. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? UT Closes #38529 from amaliujia/support_repartition_in_proto_connect. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 13 +++ .../org/apache/spark/sql/connect/dsl/package.scala | 18 ++++ .../sql/connect/planner/SparkConnectPlanner.scala | 5 + .../connect/planner/SparkConnectProtoSuite.scala | 10 ++ python/pyspark/sql/connect/proto/relations_pb2.py | 114 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 43 ++++++++ 6 files changed, 147 insertions(+), 56 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 8edd8911242..36113e2a30c 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -46,6 +46,7 @@ message Relation { Deduplicate deduplicate = 14; Range range = 15; SubqueryAlias subquery_alias = 16; + Repartition repartition = 17; Unknown unknown = 999; } @@ -241,3 +242,15 @@ message SubqueryAlias { // Optional. Qualifier of the alias. repeated string qualifier = 3; } + +// Relation repartition. +message Repartition { + // Required. 
The input relation. + Relation input = 1; + + // Required. Must be positive. + int32 num_partitions = 2; + + // Optional. Default value is false. + bool shuffle = 3; +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index c40a9eed753..2755727de11 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -423,6 +423,24 @@ package object dsl { byName)) .build() + def coalesce(num: Integer): Relation = + Relation + .newBuilder() + .setRepartition( + Repartition + .newBuilder() + .setInput(logicalPlan) + .setNumPartitions(num) + .setShuffle(false)) + .build() + + def repartition(num: Integer): Relation = + Relation + .newBuilder() + .setRepartition( + Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true)) + .build() + private def createSetOperation( left: Relation, right: Relation, diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index d2b474711ab..1615fc56ab6 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -72,6 +72,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange) case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => transformSubqueryAlias(rel.getSubqueryAlias) + case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is 
empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -107,6 +108,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { transformRelation(rel.getInput)) } + private def transformRepartition(rel: proto.Repartition): LogicalPlan = { + logical.Repartition(rel.getNumPartitions, rel.getShuffle, transformRelation(rel.getInput)) + } + private def transformRange(rel: proto.Range): LogicalPlan = { val start = rel.getStart val end = rel.getEnd diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 0aa89d6f640..72dae674721 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -251,6 +251,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connect.sql("SELECT 1"), spark.sql("SELECT 1")) } + test("Test Repartition") { + val connectPlan1 = connectTestRelation.repartition(12) + val sparkPlan1 = sparkTestRelation.repartition(12) + comparePlans(connectPlan1, sparkPlan1) + + val connectPlan2 = connectTestRelation.coalesce(2) + val sparkPlan2 = sparkTestRelation.coalesce(2) + comparePlans(connectPlan2, sparkPlan2) + } + private def createLocalRelationProtoByQualifiedAttributes( attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 6180c5e13c9..e43a5de583e 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e 
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcc\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,59 +44,61 @@ if _descriptor._USE_C_DESCRIPTORS == False: _READ_DATASOURCE_OPTIONSENTRY._options = None _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 990 - _UNKNOWN._serialized_start = 992 - _UNKNOWN._serialized_end = 1001 - _RELATIONCOMMON._serialized_start = 1003 - _RELATIONCOMMON._serialized_end = 1052 - _SQL._serialized_start = 1054 - _SQL._serialized_end = 1081 - _READ._serialized_start = 1084 - _READ._serialized_end = 1494 - _READ_NAMEDTABLE._serialized_start = 1226 - _READ_NAMEDTABLE._serialized_end = 1287 - _READ_DATASOURCE._serialized_start = 1290 - _READ_DATASOURCE._serialized_end = 1481 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1423 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1481 - _PROJECT._serialized_start = 1496 - _PROJECT._serialized_end = 1613 - _FILTER._serialized_start = 1615 - _FILTER._serialized_end = 1727 - _JOIN._serialized_start = 1730 - _JOIN._serialized_end = 2180 - _JOIN_JOINTYPE._serialized_start = 1993 - _JOIN_JOINTYPE._serialized_end = 2180 - _SETOPERATION._serialized_start = 2183 - _SETOPERATION._serialized_end = 2546 - _SETOPERATION_SETOPTYPE._serialized_start = 2432 - _SETOPERATION_SETOPTYPE._serialized_end = 2546 - _LIMIT._serialized_start = 2548 - _LIMIT._serialized_end = 2624 - _OFFSET._serialized_start = 2626 - _OFFSET._serialized_end = 2705 - _AGGREGATE._serialized_start = 2708 - _AGGREGATE._serialized_end = 2918 - _SORT._serialized_start = 2921 - _SORT._serialized_end = 3452 - _SORT_SORTFIELD._serialized_start = 3070 - _SORT_SORTFIELD._serialized_end = 3258 - _SORT_SORTDIRECTION._serialized_start = 3260 - _SORT_SORTDIRECTION._serialized_end = 3368 - _SORT_SORTNULLS._serialized_start = 3370 - _SORT_SORTNULLS._serialized_end = 3452 - _DEDUPLICATE._serialized_start = 3455 - _DEDUPLICATE._serialized_end = 3597 - _LOCALRELATION._serialized_start = 3599 - 
_LOCALRELATION._serialized_end = 3692 - _SAMPLE._serialized_start = 3695 - _SAMPLE._serialized_end = 3935 - _SAMPLE_SEED._serialized_start = 3909 - _SAMPLE_SEED._serialized_end = 3935 - _RANGE._serialized_start = 3938 - _RANGE._serialized_end = 4136 - _RANGE_NUMPARTITIONS._serialized_start = 4082 - _RANGE_NUMPARTITIONS._serialized_end = 4136 - _SUBQUERYALIAS._serialized_start = 4138 - _SUBQUERYALIAS._serialized_end = 4252 + _RELATION._serialized_end = 1054 + _UNKNOWN._serialized_start = 1056 + _UNKNOWN._serialized_end = 1065 + _RELATIONCOMMON._serialized_start = 1067 + _RELATIONCOMMON._serialized_end = 1116 + _SQL._serialized_start = 1118 + _SQL._serialized_end = 1145 + _READ._serialized_start = 1148 + _READ._serialized_end = 1558 + _READ_NAMEDTABLE._serialized_start = 1290 + _READ_NAMEDTABLE._serialized_end = 1351 + _READ_DATASOURCE._serialized_start = 1354 + _READ_DATASOURCE._serialized_end = 1545 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1487 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1545 + _PROJECT._serialized_start = 1560 + _PROJECT._serialized_end = 1677 + _FILTER._serialized_start = 1679 + _FILTER._serialized_end = 1791 + _JOIN._serialized_start = 1794 + _JOIN._serialized_end = 2244 + _JOIN_JOINTYPE._serialized_start = 2057 + _JOIN_JOINTYPE._serialized_end = 2244 + _SETOPERATION._serialized_start = 2247 + _SETOPERATION._serialized_end = 2610 + _SETOPERATION_SETOPTYPE._serialized_start = 2496 + _SETOPERATION_SETOPTYPE._serialized_end = 2610 + _LIMIT._serialized_start = 2612 + _LIMIT._serialized_end = 2688 + _OFFSET._serialized_start = 2690 + _OFFSET._serialized_end = 2769 + _AGGREGATE._serialized_start = 2772 + _AGGREGATE._serialized_end = 2982 + _SORT._serialized_start = 2985 + _SORT._serialized_end = 3516 + _SORT_SORTFIELD._serialized_start = 3134 + _SORT_SORTFIELD._serialized_end = 3322 + _SORT_SORTDIRECTION._serialized_start = 3324 + _SORT_SORTDIRECTION._serialized_end = 3432 + _SORT_SORTNULLS._serialized_start = 3434 + 
_SORT_SORTNULLS._serialized_end = 3516 + _DEDUPLICATE._serialized_start = 3519 + _DEDUPLICATE._serialized_end = 3661 + _LOCALRELATION._serialized_start = 3663 + _LOCALRELATION._serialized_end = 3756 + _SAMPLE._serialized_start = 3759 + _SAMPLE._serialized_end = 3999 + _SAMPLE_SEED._serialized_start = 3973 + _SAMPLE_SEED._serialized_end = 3999 + _RANGE._serialized_start = 4002 + _RANGE._serialized_end = 4200 + _RANGE_NUMPARTITIONS._serialized_start = 4146 + _RANGE_NUMPARTITIONS._serialized_end = 4200 + _SUBQUERYALIAS._serialized_start = 4202 + _SUBQUERYALIAS._serialized_end = 4316 + _REPARTITION._serialized_start = 4318 + _REPARTITION._serialized_end = 4443 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index f5b5c9f90dc..30c1dddf885 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -75,6 +75,7 @@ class Relation(google.protobuf.message.Message): DEDUPLICATE_FIELD_NUMBER: builtins.int RANGE_FIELD_NUMBER: builtins.int SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int + REPARTITION_FIELD_NUMBER: builtins.int UNKNOWN_FIELD_NUMBER: builtins.int @property def common(self) -> global___RelationCommon: ... @@ -109,6 +110,8 @@ class Relation(google.protobuf.message.Message): @property def subquery_alias(self) -> global___SubqueryAlias: ... @property + def repartition(self) -> global___Repartition: ... + @property def unknown(self) -> global___Unknown: ... def __init__( self, @@ -129,6 +132,7 @@ class Relation(google.protobuf.message.Message): deduplicate: global___Deduplicate | None = ..., range: global___Range | None = ..., subquery_alias: global___SubqueryAlias | None = ..., + repartition: global___Repartition | None = ..., unknown: global___Unknown | None = ..., ) -> None: ... 
def HasField( @@ -158,6 +162,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "repartition", + b"repartition", "sample", b"sample", "set_op", @@ -199,6 +205,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "repartition", + b"repartition", "sample", b"sample", "set_op", @@ -231,6 +239,7 @@ class Relation(google.protobuf.message.Message): "deduplicate", "range", "subquery_alias", + "repartition", "unknown", ] | None: ... @@ -1022,3 +1031,37 @@ class SubqueryAlias(google.protobuf.message.Message): ) -> None: ... global___SubqueryAlias = SubqueryAlias + +class Repartition(google.protobuf.message.Message): + """Relation repartition.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + NUM_PARTITIONS_FIELD_NUMBER: builtins.int + SHUFFLE_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """Required. The input relation.""" + num_partitions: builtins.int + """Required. Must be positive.""" + shuffle: builtins.bool + """Optional. Default value is false.""" + def __init__( + self, + *, + input: global___Relation | None = ..., + num_partitions: builtins.int = ..., + shuffle: builtins.bool = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "input", b"input", "num_partitions", b"num_partitions", "shuffle", b"shuffle" + ], + ) -> None: ... + +global___Repartition = Repartition --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org