This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 46fab54b500 [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
46fab54b500 is described below

commit 46fab54b500c579cd421fb9e8ea95fae0ddda87d
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Nov 11 12:04:34 2022 +0900

    [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset

    ### What changes were proposed in this pull request?

    We used to wrap these fields in dedicated messages to be able to tell whether a field is set or unset. It turns out proto3 offers a built-in mechanism that achieves the same thing: https://developers.google.com/protocol-buffers/docs/proto3#specifying_field_rules. Adding the `optional` keyword to a field is enough to auto-generate a `hasXXX` method for it. This PR refactors the existing protos to get rid of the redundant message definitions.

    ### Why are the changes needed?

    Codebase simplification.

    ### Does this PR introduce _any_ user-facing change?

    NO

    ### How was this patch tested?

    Existing UT

Closes #38606 from amaliujia/refactor_proto.

Authored-by: Rui Wang <rui.w...@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    | 12 +---
 .../org/apache/spark/sql/connect/dsl/package.scala |  5 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  4 +-
 python/pyspark/sql/connect/plan.py                 |  6 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 40 ++++++------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 71 ++++++++++------------
 .../sql/tests/connect/test_connect_plan_only.py    |  4 +-
 7 files changed, 61 insertions(+), 81 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 639d1bafce5..4f30b5bfbde 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -215,11 +215,7 @@ message Sample {
   double lower_bound = 2;
   double upper_bound = 3;
   bool with_replacement = 4;
-  Seed seed = 5;
-
-  message Seed {
-    int64 seed = 1;
-  }
+  optional int64 seed = 5;
 }
 
 // Relation of type [[Range]] that generates a sequence of integers.
@@ -232,11 +228,7 @@ message Range {
   int64 step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
   // it is set, or 2) spark default parallelism.
-  NumPartitions num_partitions = 4;
-
-  message NumPartitions {
-    int32 num_partitions = 1;
-  }
+  optional int32 num_partitions = 4;
 }
 
 // Relation alias.
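For illustration, a minimal sketch of the presence semantics that `optional` buys in the regenerated Python bindings (the `Sample` message and its `seed` field come from the diff above; the standalone script itself is hypothetical, assuming a PySpark build that ships `relations_pb2`):

    from pyspark.sql.connect.proto import relations_pb2

    sample = relations_pb2.Sample()

    # With `optional`, an unset scalar reports absent instead of silently
    # defaulting to 0, so no wrapper message is needed to track presence.
    assert not sample.HasField("seed")

    # The field is now a plain scalar: assign it directly.
    sample.seed = -1
    assert sample.HasField("seed")
    assert sample.seed == -1

    # Clearing the field restores the unset state; on the server side the
    # planner then falls back to Utils.random.nextLong (per the Scala diff).
    sample.ClearField("seed")
    assert not sample.HasField("seed")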
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 5e7a94da347..f55ed835d23 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -216,8 +216,7 @@ package object dsl {
           range.setStep(1L)
         }
         if (numPartitions.isDefined) {
-          range.setNumPartitions(
-            proto.Range.NumPartitions.newBuilder().setNumPartitions(numPartitions.get))
+          range.setNumPartitions(numPartitions.get)
         }
         Relation.newBuilder().setRange(range).build()
       }
@@ -376,7 +375,7 @@ package object dsl {
             .setUpperBound(upperBound)
             .setLowerBound(lowerBound)
             .setWithReplacement(withReplacement)
-            .setSeed(Sample.Seed.newBuilder().setSeed(seed).build())
+            .setSeed(seed)
             .build())
           .build()
       }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 04ce880a925..b91fef58a11 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -104,7 +104,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       rel.getLowerBound,
       rel.getUpperBound,
       rel.getWithReplacement,
-      if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
+      if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
       transformRelation(rel.getInput))
   }
 
@@ -117,7 +117,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     val end = rel.getEnd
     val step = rel.getStep
     val numPartitions = if (rel.hasNumPartitions) {
-      rel.getNumPartitions.getNumPartitions
+      rel.getNumPartitions
     } else {
       session.leafNodeDefaultParallelism
     }
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e5eed195568..be1060a9fd8 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -443,7 +443,7 @@ class Sample(LogicalPlan):
         plan.sample.upper_bound = self.upper_bound
         plan.sample.with_replacement = self.with_replacement
         if self.seed is not None:
-            plan.sample.seed.seed = self.seed
+            plan.sample.seed = self.seed
         return plan
 
     def print(self, indent: int = 0) -> str:
@@ -777,9 +777,7 @@ class Range(LogicalPlan):
         rel.range.end = self._end
         rel.range.step = self._step
         if self._num_partitions is not None:
-            num_partitions_proto = rel.range.NumPartitions()
-            num_partitions_proto.num_partitions = self._num_partitions
-            rel.range.num_partitions.CopyFrom(num_partitions_proto)
+            rel.range.num_partitions = self._num_partitions
         return rel
 
     def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 323eb8e7690..73b789cf7d6 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -92,25 +92,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _LOCALRELATION._serialized_start = 4025
     _LOCALRELATION._serialized_end = 4118
     _SAMPLE._serialized_start = 4121
-    _SAMPLE._serialized_end = 4361
-    _SAMPLE_SEED._serialized_start = 4335
-    _SAMPLE_SEED._serialized_end = 4361
-    _RANGE._serialized_start = 4364
-    _RANGE._serialized_end = 4562
-    _RANGE_NUMPARTITIONS._serialized_start = 4508
-    _RANGE_NUMPARTITIONS._serialized_end = 4562
-    _SUBQUERYALIAS._serialized_start = 4564
-    _SUBQUERYALIAS._serialized_end = 4678
-    _REPARTITION._serialized_start = 4680
-    _REPARTITION._serialized_end = 4805
-    _STATSUMMARY._serialized_start = 4807
-    _STATSUMMARY._serialized_end = 4899
-    _STATCROSSTAB._serialized_start = 4901
-    _STATCROSSTAB._serialized_end = 5002
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5004
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5118
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5121
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5380
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5313
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5380
+    _SAMPLE._serialized_end = 4319
+    _RANGE._serialized_start = 4322
+    _RANGE._serialized_end = 4452
+    _SUBQUERYALIAS._serialized_start = 4454
+    _SUBQUERYALIAS._serialized_end = 4568
+    _REPARTITION._serialized_start = 4570
+    _REPARTITION._serialized_end = 4695
+    _STATSUMMARY._serialized_start = 4697
+    _STATSUMMARY._serialized_end = 4789
+    _STATCROSSTAB._serialized_start = 4791
+    _STATCROSSTAB._serialized_end = 4892
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 4894
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5008
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5011
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5270
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5203
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5270
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 53f75b7520f..e706fa3e11d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -924,18 +924,6 @@ class Sample(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class Seed(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        SEED_FIELD_NUMBER: builtins.int
-        seed: builtins.int
-        def __init__(
-            self,
-            *,
-            seed: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ...
-
     INPUT_FIELD_NUMBER: builtins.int
     LOWER_BOUND_FIELD_NUMBER: builtins.int
     UPPER_BOUND_FIELD_NUMBER: builtins.int
@@ -946,8 +934,7 @@ class Sample(google.protobuf.message.Message):
     lower_bound: builtins.float
     upper_bound: builtins.float
     with_replacement: builtins.bool
-    @property
-    def seed(self) -> global___Sample.Seed: ...
+    seed: builtins.int
     def __init__(
         self,
         *,
@@ -955,14 +942,19 @@ class Sample(google.protobuf.message.Message):
         lower_bound: builtins.float = ...,
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool = ...,
-        seed: global___Sample.Seed | None = ...,
+        seed: builtins.int | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_seed", b"_seed", "input", b"input", "seed", b"seed"
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_seed",
+            b"_seed",
             "input",
             b"input",
             "lower_bound",
@@ -975,6 +967,9 @@ class Sample(google.protobuf.message.Message):
             b"with_replacement",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
+    ) -> typing_extensions.Literal["seed"] | None: ...
 
 global___Sample = Sample
 
@@ -983,20 +978,6 @@ class Range(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class NumPartitions(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        NUM_PARTITIONS_FIELD_NUMBER: builtins.int
-        num_partitions: builtins.int
-        def __init__(
-            self,
-            *,
-            num_partitions: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(
-            self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
-        ) -> None: ...
-
     START_FIELD_NUMBER: builtins.int
     END_FIELD_NUMBER: builtins.int
     STEP_FIELD_NUMBER: builtins.int
@@ -1007,28 +988,42 @@ class Range(google.protobuf.message.Message):
     """Required."""
     step: builtins.int
     """Required."""
-    @property
-    def num_partitions(self) -> global___Range.NumPartitions:
-        """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
-        it is set, or 2) spark default parallelism.
-        """
+    num_partitions: builtins.int
+    """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
+    it is set, or 2) spark default parallelism.
+    """
     def __init__(
         self,
         *,
         start: builtins.int = ...,
         end: builtins.int = ...,
         step: builtins.int = ...,
-        num_partitions: global___Range.NumPartitions | None = ...,
+        num_partitions: builtins.int | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_num_partitions", b"_num_partitions", "num_partitions", b"num_partitions"
+        ],
    ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "end", b"end", "num_partitions", b"num_partitions", "start", b"start", "step", b"step"
+            "_num_partitions",
+            b"_num_partitions",
+            "end",
+            b"end",
+            "num_partitions",
+            b"num_partitions",
+            "start",
+            b"start",
+            "step",
+            b"step",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_num_partitions", b"_num_partitions"]
+    ) -> typing_extensions.Literal["num_partitions"] | None: ...
 
 global___Range = Range
 
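As the regenerated stubs above show, proto3 implements `optional` via a synthetic oneof (`_seed`, `_num_partitions`), so presence can also be probed through `WhichOneof`. A short illustrative sketch (message and field names come from this patch; the script itself is hypothetical):

    from pyspark.sql.connect.proto import relations_pb2

    rng = relations_pb2.Range(start=10, end=20, step=3)

    # The synthetic oneof "_num_partitions" reports None while the
    # optional field is unset...
    assert rng.WhichOneof("_num_partitions") is None

    # ...and the field name once a value has been assigned.
    rng.num_partitions = 4
    assert rng.WhichOneof("_num_partitions") == "num_partitions"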
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index c46d4d10624..4e26581a002 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -121,7 +121,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.lower_bound, 0.0)
         self.assertEqual(plan.root.sample.upper_bound, 0.4)
         self.assertEqual(plan.root.sample.with_replacement, True)
-        self.assertEqual(plan.root.sample.seed.seed, -1)
+        self.assertEqual(plan.root.sample.seed, -1)
 
     def test_sort(self):
         df = self.connect.readTable(table_name=self.tbl_name)
@@ -180,7 +180,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
         self.assertEqual(plan.root.range.step, 3)
-        self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
+        self.assertEqual(plan.root.range.num_partitions, 4)
 
         plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
         self.assertEqual(plan.root.range.start, 10)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org