This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f03fdf90281 [SPARK-40883][CONNECT][FOLLOW-UP] Range.step is required and Python client should have a default value=1
f03fdf90281 is described below

commit f03fdf90281d67065b9ab211b5cd9cfbe5742614
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Nov 2 14:10:13 2022 +0800

    [SPARK-40883][CONNECT][FOLLOW-UP] Range.step is required and Python client should have a default value=1
    
    ### What changes were proposed in this pull request?
    
    To match the existing Python DataFrame API, this PR makes `Range.step` a required field and has the Python client keep `1` as the default value for it.
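    
    A minimal sketch of the new client-side behavior, adapted from the updated `test_connect_plan_only.py` test in this patch (`self.connect` is the plan-only test fixture's remote session; the bare `assert` stands in for the test's `assertEqual`):
    
    ```python
    # Omitting `step` now produces a proto plan with the default step of 1,
    # rather than leaving an optional Step message unset.
    plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
    assert plan.root.range.step == 1
    ```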
    
    ### Why are the changes needed?
    
    Matching existing DataFrame API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    Closes #38471 from amaliujia/range_step_required.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  8 ++------
 .../org/apache/spark/sql/connect/dsl/package.scala |  4 +++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  6 +-----
 python/pyspark/sql/connect/client.py               |  2 +-
 python/pyspark/sql/connect/plan.py                 |  7 ++-----
 python/pyspark/sql/connect/proto/relations_pb2.py  | 14 ++++++--------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 22 ++++------------------
 .../sql/tests/connect/test_connect_plan_only.py    |  4 ++--
 python/pyspark/testing/connectutils.py             |  2 +-
 9 files changed, 22 insertions(+), 47 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index a4503204aa1..deb35525728 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -226,16 +226,12 @@ message Range {
   int64 start = 1;
   // Required.
   int64 end = 2;
-  // Optional. Default value = 1
-  Step step = 3;
+  // Required.
+  int64 step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
   // it is set, or 2) spark default parallelism.
   NumPartitions num_partitions = 4;
 
-  message Step {
-    int64 step = 1;
-  }
-
   message NumPartitions {
     int32 num_partitions = 1;
   }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index f649d040721..e2030c9ad31 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -190,7 +190,9 @@ package object dsl {
         }
         range.setEnd(end)
         if (step.isDefined) {
-          range.setStep(proto.Range.Step.newBuilder().setStep(step.get))
+          range.setStep(step.get)
+        } else {
+          range.setStep(1L)
         }
         if (numPartitions.isDefined) {
           range.setNumPartitions(
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index eea2579e61f..f5c6980290f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -110,11 +110,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
   private def transformRange(rel: proto.Range): LogicalPlan = {
     val start = rel.getStart
     val end = rel.getEnd
-    val step = if (rel.hasStep) {
-      rel.getStep.getStep
-    } else {
-      1
-    }
+    val step = rel.getStep
     val numPartitions = if (rel.hasNumPartitions) {
       rel.getNumPartitions.getNumPartitions
     } else {
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index e64d612c53e..c845d378320 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -149,7 +149,7 @@ class RemoteSparkSession(object):
         self,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int = 1,
         numPartitions: Optional[int] = None,
     ) -> DataFrame:
         """
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 71c971d9e91..2f1f70ec1a9 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -705,7 +705,7 @@ class Range(LogicalPlan):
         self,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int,
         num_partitions: Optional[int] = None,
     ) -> None:
         super().__init__(None)
@@ -718,10 +718,7 @@ class Range(LogicalPlan):
         rel = proto.Relation()
         rel.range.start = self._start
         rel.range.end = self._end
-        if self._step is not None:
-            step_proto = rel.range.Step()
-            step_proto.step = self._step
-            rel.range.step.CopyFrom(step_proto)
+        rel.range.step = self._step
         if self._num_partitions is not None:
             num_partitions_proto = rel.range.NumPartitions()
             num_partitions_proto.num_partitions = self._num_partitions
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6741341326a..3d5eb53e5a9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -96,11 +96,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _SAMPLE_SEED._serialized_start = 4024
     _SAMPLE_SEED._serialized_end = 4050
     _RANGE._serialized_start = 4053
-    _RANGE._serialized_end = 4306
-    _RANGE_STEP._serialized_start = 4224
-    _RANGE_STEP._serialized_end = 4250
-    _RANGE_NUMPARTITIONS._serialized_start = 4252
-    _RANGE_NUMPARTITIONS._serialized_end = 4306
-    _SUBQUERYALIAS._serialized_start = 4308
-    _SUBQUERYALIAS._serialized_end = 4422
+    _RANGE._serialized_end = 4251
+    _RANGE_NUMPARTITIONS._serialized_start = 4197
+    _RANGE_NUMPARTITIONS._serialized_end = 4251
+    _SUBQUERYALIAS._serialized_start = 4253
+    _SUBQUERYALIAS._serialized_end = 4367
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 34fbc2c300f..60f4e2033a8 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -960,18 +960,6 @@ class Range(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class Step(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        STEP_FIELD_NUMBER: builtins.int
-        step: builtins.int
-        def __init__(
-            self,
-            *,
-            step: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(self, field_name: typing_extensions.Literal["step", b"step"]) -> None: ...
-
     class NumPartitions(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -994,9 +982,8 @@ class Range(google.protobuf.message.Message):
     """Optional. Default value = 0"""
     end: builtins.int
     """Required."""
-    @property
-    def step(self) -> global___Range.Step:
-        """Optional. Default value = 1"""
+    step: builtins.int
+    """Required."""
     @property
     def num_partitions(self) -> global___Range.NumPartitions:
         """Optional. Default value is assigned by 1) SQL conf 
"spark.sql.leafNodeDefaultParallelism" if
@@ -1007,12 +994,11 @@ class Range(google.protobuf.message.Message):
         *,
         start: builtins.int = ...,
         end: builtins.int = ...,
-        step: global___Range.Step | None = ...,
+        step: builtins.int = ...,
         num_partitions: global___Range.NumPartitions | None = ...,
     ) -> None: ...
     def HasField(
-        self,
-        field_name: typing_extensions.Literal["num_partitions", b"num_partitions", "step", b"step"],
+        self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 054d2b00088..902e9feeb3c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -150,13 +150,13 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         )
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
-        self.assertEqual(plan.root.range.step.step, 3)
+        self.assertEqual(plan.root.range.step, 3)
         self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
 
         plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
-        self.assertFalse(plan.root.range.HasField("step"))
+        self.assertEqual(plan.root.range.step, 1)
         self.assertFalse(plan.root.range.HasField("num_partitions"))
 
     def test_datasource_read(self):
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 9d0aa7f6884..b94306406c3 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -80,7 +80,7 @@ class PlanOnlyTestFixture(unittest.TestCase):
         cls,
         start: int,
         end: int,
-        step: Optional[int] = None,
+        step: int = 1,
         num_partitions: Optional[int] = None,
     ) -> "DataFrame":
         return DataFrame.withPlan(

