This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 968463b070e [SPARK-40981][CONNECT][PYTHON] Support session.range in Python client
968463b070e is described below

commit 968463b070eac325f7d018b13e27c5694f33089e
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Nov 1 16:51:33 2022 +0800

    [SPARK-40981][CONNECT][PYTHON] Support session.range in Python client
    
    ### What changes were proposed in this pull request?
    
    This PR adds `range` API to Python client's `RemoteSparkSession` with tests.
    
    This PR also updates `start`, `end`, `step` to `int64` in the Connect proto.
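    
    For illustration, a minimal usage sketch of the new API, assuming an already-connected `RemoteSparkSession` named `session` (how the session is created is not shown here; the values below are arbitrary):
    
    ```python
    # range() returns a DataFrame with a single Long column named `id`,
    # covering [start, end) with the given step.
    df = session.range(start=0, end=10, step=3, numPartitions=2)
    
    # Collect to the client as a pandas DataFrame; expected id values: 0, 3, 6, 9.
    print(df.toPandas())
    ```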
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    Closes #38460 from amaliujia/SPARK-40981.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  6 +--
 .../org/apache/spark/sql/connect/dsl/package.scala |  6 +--
 python/pyspark/sql/connect/client.py               | 35 ++++++++++++++-
 python/pyspark/sql/connect/plan.py                 | 50 ++++++++++++++++++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  |  2 +-
 .../sql/tests/connect/test_connect_basic.py        | 17 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    | 15 +++++++
 python/pyspark/testing/connectutils.py             | 18 +++++++-
 8 files changed, 139 insertions(+), 10 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index e88e70ceb73..a4503204aa1 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -223,9 +223,9 @@ message Sample {
 // Relation of type [[Range]] that generates a sequence of integers.
 message Range {
   // Optional. Default value = 0
-  int32 start = 1;
+  int64 start = 1;
   // Required.
-  int32 end = 2;
+  int64 end = 2;
   // Optional. Default value = 1
   Step step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
@@ -233,7 +233,7 @@ message Range {
   NumPartitions num_partitions = 4;
 
   message Step {
-    int32 step = 1;
+    int64 step = 1;
   }
 
   message NumPartitions {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 54e51868c75..067b6e42ec2 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -180,9 +180,9 @@ package object dsl {
   object plans { // scalastyle:ignore
     implicit class DslMockRemoteSession(val session: MockRemoteSession) {
       def range(
-          start: Option[Int],
-          end: Int,
-          step: Option[Int],
+          start: Option[Long],
+          end: Long,
+          step: Option[Long],
           numPartitions: Option[Int]): Relation = {
         val range = proto.Range.newBuilder()
         if (start.isDefined) {
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index f4b6d2ec302..e64d612c53e 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -32,7 +32,7 @@ import pyspark.sql.types
 from pyspark import cloudpickle
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.readwriter import DataFrameReader
-from pyspark.sql.connect.plan import SQL
+from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
 
 from typing import Optional, Any, Union
@@ -145,6 +145,39 @@ class RemoteSparkSession(object):
     def sql(self, sql_string: str) -> "DataFrame":
         return DataFrame.withPlan(SQL(sql_string), self)
 
+    def range(
+        self,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        numPartitions: Optional[int] = None,
+    ) -> DataFrame:
+        """
+        Create a :class:`DataFrame` with column named ``id`` and typed Long,
+        containing elements in a range from ``start`` to ``end`` (exclusive) with
+        step value ``step``.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        start : int
+            the start value
+        end : int
+            the end value (exclusive)
+        step : int, optional
+            the incremental step (default: 1)
+        numPartitions : int, optional
+            the number of partitions of the DataFrame
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return DataFrame.withPlan(
+            Range(start=start, end=end, step=step, num_partitions=numPartitions), self
+        )
+
     def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 56c609d9576..71c971d9e91 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -698,3 +698,53 @@ class SQL(LogicalPlan):
            </li>
         </ul>
         """
+
+
+class Range(LogicalPlan):
+    def __init__(
+        self,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+    ) -> None:
+        super().__init__(None)
+        self._start = start
+        self._end = end
+        self._step = step
+        self._num_partitions = num_partitions
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        rel = proto.Relation()
+        rel.range.start = self._start
+        rel.range.end = self._end
+        if self._step is not None:
+            step_proto = rel.range.Step()
+            step_proto.step = self._step
+            rel.range.step.CopyFrom(step_proto)
+        if self._num_partitions is not None:
+            num_partitions_proto = rel.range.NumPartitions()
+            num_partitions_proto.num_partitions = self._num_partitions
+            rel.range.num_partitions.CopyFrom(num_partitions_proto)
+        return rel
+
+    def print(self, indent: int = 0) -> str:
+        return (
+            f"{' ' * indent}"
+            f"<Range start={self._start}, end={self._end}, "
+            f"step={self._step}, num_partitions={self._num_partitions}>"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>Range</b><br />
+                Start: {self._start} <br />
+                End: {self._end} <br />
+                Step: {self._step} <br />
+                NumPartitions: {self._num_partitions} <br />
+                {self._child_repr_()}
+            </li>
+        </ul>
+        """
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 5db36434e30..6741341326a 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 11d51f73476..e9a06f9c545 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -134,6 +134,23 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         self.assertIsNotNone(pd)
         self.assertEqual(10, len(pd.index))
 
+    def test_range(self):
+        self.assertTrue(
+            self.connect.range(start=0, end=10)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10).toPandas())
+        )
+        self.assertTrue(
+            self.connect.range(start=0, end=10, step=3)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10, step=3).toPandas())
+        )
+        self.assertTrue(
+            self.connect.range(start=0, end=10, step=3, numPartitions=2)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
+        )
+
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
         tmpPath = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index a295c9612be..054d2b00088 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -144,6 +144,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         plan = df.alias("table_alias")._plan.to_proto(self.connect)
         self.assertEqual(plan.root.subquery_alias.alias, "table_alias")
 
+    def test_range(self):
+        plan = self.connect.range(start=10, end=20, step=3, num_partitions=4)._plan.to_proto(
+            self.connect
+        )
+        self.assertEqual(plan.root.range.start, 10)
+        self.assertEqual(plan.root.range.end, 20)
+        self.assertEqual(plan.root.range.step.step, 3)
+        self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
+
+        plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.range.start, 10)
+        self.assertEqual(plan.root.range.end, 20)
+        self.assertFalse(plan.root.range.HasField("step"))
+        self.assertFalse(plan.root.range.HasField("num_partitions"))
+
     def test_datasource_read(self):
         reader = DataFrameReader(self.connect)
         df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2")
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index d9bced3af11..9d0aa7f6884 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -15,14 +15,14 @@
 # limitations under the License.
 #
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import functools
 import unittest
 from pyspark.testing.sqlutils import have_pandas
 
 if have_pandas:
     from pyspark.sql.connect import DataFrame
-    from pyspark.sql.connect.plan import Read
+    from pyspark.sql.connect.plan import Read, Range
     from pyspark.testing.utils import search_jar
 
     connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
@@ -75,6 +75,18 @@ class PlanOnlyTestFixture(unittest.TestCase):
     def _udf_mock(cls, *args, **kwargs) -> str:
         return "internal_name"
 
+    @classmethod
+    def _session_range(
+        cls,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+    ) -> "DataFrame":
+        return DataFrame.withPlan(
+            Range(start, end, step, num_partitions), cls.connect  # type: ignore
+        )
+
     @classmethod
     def setUpClass(cls: Any) -> None:
         cls.connect = MockRemoteSession()
@@ -82,8 +94,10 @@ class PlanOnlyTestFixture(unittest.TestCase):
 
         cls.connect.set_hook("register_udf", cls._udf_mock)
         cls.connect.set_hook("readTable", cls._read_table)
+        cls.connect.set_hook("range", cls._session_range)
 
     @classmethod
     def tearDownClass(cls: Any) -> None:
         cls.connect.drop_hook("register_udf")
         cls.connect.drop_hook("readTable")
+        cls.connect.drop_hook("range")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
