This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f03fdf90281 [SPARK-40883][CONNECT][FOLLOW-UP] Range.step is required and Python client should have a default value=1
f03fdf90281 is described below

commit f03fdf90281d67065b9ab211b5cd9cfbe5742614
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Nov 2 14:10:13 2022 +0800

    [SPARK-40883][CONNECT][FOLLOW-UP] Range.step is required and Python client should have a default value=1

    ### What changes were proposed in this pull request?

    To match the existing Python DataFrame API, this PR makes `Range.step` a required field, and the Python client keeps `1` as the default value for it.

    ### Why are the changes needed?

    Matching the existing DataFrame API.

    ### Does this PR introduce _any_ user-facing change?

    NO

    ### How was this patch tested?

    UT

    Closes #38471 from amaliujia/range_step_required.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto     |  8 ++------
 .../org/apache/spark/sql/connect/dsl/package.scala  |  4 +++-
 .../sql/connect/planner/SparkConnectPlanner.scala   |  6 +-----
 python/pyspark/sql/connect/client.py                |  2 +-
 python/pyspark/sql/connect/plan.py                  |  7 ++-----
 python/pyspark/sql/connect/proto/relations_pb2.py   | 14 ++++++--------
 python/pyspark/sql/connect/proto/relations_pb2.pyi  | 22 ++++------------------
 .../sql/tests/connect/test_connect_plan_only.py     |  4 ++--
 python/pyspark/testing/connectutils.py              |  2 +-
 9 files changed, 22 insertions(+), 47 deletions(-)
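
For context, the behavior being matched is visible in the classic (non-Connect) PySpark API, where SparkSession.range already defaults step to 1. A minimal sketch, not part of the patch, assuming only a local pyspark installation:

    # Plain PySpark (not Spark Connect): SparkSession.range already treats
    # step as defaulting to 1, which is the behavior this patch mirrors.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    assert spark.range(10, 20).count() == 10          # step defaults to 1
    assert spark.range(10, 20, step=3).count() == 4   # 10, 13, 16, 19
    spark.stop()
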
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index a4503204aa1..deb35525728 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -226,16 +226,12 @@ message Range {
   int64 start = 1;
   // Required.
   int64 end = 2;
-  // Optional. Default value = 1
-  Step step = 3;
+  // Required.
+  int64 step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
   // it is set, or 2) spark default parallelism.
   NumPartitions num_partitions = 4;

-  message Step {
-    int64 step = 1;
-  }
-
   message NumPartitions {
     int32 num_partitions = 1;
   }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index f649d040721..e2030c9ad31 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -190,7 +190,9 @@ package object dsl {
        }
        range.setEnd(end)
        if (step.isDefined) {
-          range.setStep(proto.Range.Step.newBuilder().setStep(step.get))
+          range.setStep(step.get)
+        } else {
+          range.setStep(1L)
        }
        if (numPartitions.isDefined) {
          range.setNumPartitions(
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index eea2579e61f..f5c6980290f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -110,11 +110,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
   private def transformRange(rel: proto.Range): LogicalPlan = {
     val start = rel.getStart
     val end = rel.getEnd
-    val step = if (rel.hasStep) {
-      rel.getStep.getStep
-    } else {
-      1
-    }
+    val step = rel.getStep
     val numPartitions = if (rel.hasNumPartitions) {
       rel.getNumPartitions.getNumPartitions
     } else {
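
Worth spelling out why the server side can simply read `rel.getStep`: once `step` is a plain proto3 `int64` instead of a nested `Step` message, the field has no presence bit, so "unset" and `0` are indistinguishable on the wire; the default therefore has to live in the client, which now always sets the field. A small sketch of that proto3 behavior, assuming the regenerated `relations_pb2` module below is importable:

    # proto3 scalar fields have no presence bit: an unset int64 reads as 0, and
    # HasField() is rejected for it, while message-typed fields keep presence.
    from pyspark.sql.connect.proto import relations_pb2

    rng = relations_pb2.Range()
    print(rng.step)                        # 0 -- indistinguishable from explicit 0
    try:
        rng.HasField("step")               # scalar field: no presence tracking
    except ValueError as e:
        print(e)
    print(rng.HasField("num_partitions"))  # False -- message field keeps presence
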
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index e64d612c53e..c845d378320 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -149,7 +149,7 @@ class RemoteSparkSession(object):
         self,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int = 1,
         numPartitions: Optional[int] = None,
     ) -> DataFrame:
         """
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 71c971d9e91..2f1f70ec1a9 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -705,7 +705,7 @@ class Range(LogicalPlan):
         self,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int,
         num_partitions: Optional[int] = None,
     ) -> None:
         super().__init__(None)
@@ -718,10 +718,7 @@ class Range(LogicalPlan):
         rel = proto.Relation()
         rel.range.start = self._start
         rel.range.end = self._end
-        if self._step is not None:
-            step_proto = rel.range.Step()
-            step_proto.step = self._step
-            rel.range.step.CopyFrom(step_proto)
+        rel.range.step = self._step
         if self._num_partitions is not None:
             num_partitions_proto = rel.range.NumPartitions()
             num_partitions_proto.num_partitions = self._num_partitions
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6741341326a..3d5eb53e5a9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e

 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -96,11 +96,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _SAMPLE_SEED._serialized_start = 4024
     _SAMPLE_SEED._serialized_end = 4050
     _RANGE._serialized_start = 4053
-    _RANGE._serialized_end = 4306
-    _RANGE_STEP._serialized_start = 4224
-    _RANGE_STEP._serialized_end = 4250
-    _RANGE_NUMPARTITIONS._serialized_start = 4252
-    _RANGE_NUMPARTITIONS._serialized_end = 4306
-    _SUBQUERYALIAS._serialized_start = 4308
-    _SUBQUERYALIAS._serialized_end = 4422
+    _RANGE._serialized_end = 4251
+    _RANGE_NUMPARTITIONS._serialized_start = 4197
+    _RANGE_NUMPARTITIONS._serialized_end = 4251
+    _SUBQUERYALIAS._serialized_start = 4253
+    _SUBQUERYALIAS._serialized_end = 4367
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 34fbc2c300f..60f4e2033a8 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -960,18 +960,6 @@ class Range(google.protobuf.message.Message):

     DESCRIPTOR: google.protobuf.descriptor.Descriptor

-    class Step(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        STEP_FIELD_NUMBER: builtins.int
-        step: builtins.int
-        def __init__(
-            self,
-            *,
-            step: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(self, field_name: typing_extensions.Literal["step", b"step"]) -> None: ...
-
     class NumPartitions(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -994,9 +982,8 @@ class Range(google.protobuf.message.Message):
     """Optional. Default value = 0"""
     end: builtins.int
     """Required."""
-    @property
-    def step(self) -> global___Range.Step:
-        """Optional. Default value = 1"""
+    step: builtins.int
+    """Required."""
     @property
     def num_partitions(self) -> global___Range.NumPartitions:
         """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
@@ -1007,12 +994,11 @@ class Range(google.protobuf.message.Message):
         *,
         start: builtins.int = ...,
         end: builtins.int = ...,
-        step: global___Range.Step | None = ...,
+        step: builtins.int = ...,
         num_partitions: global___Range.NumPartitions | None = ...,
     ) -> None: ...
     def HasField(
-        self,
-        field_name: typing_extensions.Literal["num_partitions", b"num_partitions", "step", b"step"],
+        self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
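
As a usage sketch of the regenerated bindings (mirroring the plan.py change above; it assumes pyspark.sql.connect.proto re-exports the generated classes, as plan.py's proto.Relation() usage suggests):

    # Populating Range on the client now assigns step directly; the old code
    # had to build a Range.Step wrapper and CopyFrom() it into the relation.
    from pyspark.sql.connect import proto

    rel = proto.Relation()
    rel.range.start = 10
    rel.range.end = 20
    rel.range.step = 3     # previously: rel.range.step.CopyFrom(rel.range.Step(step=3))
    assert rel.range.step == 3
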
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 054d2b00088..902e9feeb3c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -150,13 +150,13 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         )
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
-        self.assertEqual(plan.root.range.step.step, 3)
+        self.assertEqual(plan.root.range.step, 3)
         self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)

         plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
-        self.assertFalse(plan.root.range.HasField("step"))
+        self.assertEqual(plan.root.range.step, 1)
         self.assertFalse(plan.root.range.HasField("num_partitions"))

     def test_datasource_read(self):
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 9d0aa7f6884..b94306406c3 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -80,7 +80,7 @@ class PlanOnlyTestFixture(unittest.TestCase):
         cls,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int = 1,
         num_partitions: Optional[int] = None,
     ) -> "DataFrame":
         return DataFrame.withPlan(