This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 968463b070e  [SPARK-40981][CONNECT][PYTHON] Support session.range in Python client
968463b070e is described below

commit 968463b070eac325f7d018b13e27c5694f33089e
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Nov 1 16:51:33 2022 +0800

    [SPARK-40981][CONNECT][PYTHON] Support session.range in Python client

    ### What changes were proposed in this pull request?

    This PR adds a `range` API to the Python client's `RemoteSparkSession`, with tests. It also updates `start`, `end`, and `step` to `int64` in the Connect proto.

    ### Why are the changes needed?

    Improve API coverage.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    UT.

    Closes #38460 from amaliujia/SPARK-40981.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  6 +--
 .../org/apache/spark/sql/connect/dsl/package.scala |  6 +--
 python/pyspark/sql/connect/client.py               | 35 ++++++++++++++-
 python/pyspark/sql/connect/plan.py                 | 50 ++++++++++++++++++++++
 python/pyspark/sql/connect/proto/relations_pb2.py  |  2 +-
 .../sql/tests/connect/test_connect_basic.py        | 17 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    | 15 +++++++
 python/pyspark/testing/connectutils.py             | 18 +++++++-
 8 files changed, 139 insertions(+), 10 deletions(-)
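For illustration, a minimal sketch of how the new API is used (assuming `session` is an already-connected `RemoteSparkSession`; how that session is constructed is outside the scope of this change):

```python
# `end` is exclusive, so this yields a single Long column named `id`
# holding 0, 2, 4, 6, 8, split across 2 partitions.
df = session.range(start=0, end=10, step=2, numPartitions=2)

# Omitting `step` / `numPartitions` leaves the proto fields unset, and the
# server applies the documented defaults (step = 1, conf-driven parallelism).
df2 = session.range(0, 10)

print(df.toPandas())
```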
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index e88e70ceb73..a4503204aa1 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -223,9 +223,9 @@ message Sample {
 // Relation of type [[Range]] that generates a sequence of integers.
 message Range {
   // Optional. Default value = 0
-  int32 start = 1;
+  int64 start = 1;
   // Required.
-  int32 end = 2;
+  int64 end = 2;
   // Optional. Default value = 1
   Step step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
@@ -233,7 +233,7 @@ message Range {
   NumPartitions num_partitions = 4;
 
   message Step {
-    int32 step = 1;
+    int64 step = 1;
   }
 
   message NumPartitions {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 54e51868c75..067b6e42ec2 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -180,9 +180,9 @@ package object dsl {
   object plans { // scalastyle:ignore
     implicit class DslMockRemoteSession(val session: MockRemoteSession) {
       def range(
-          start: Option[Int],
-          end: Int,
-          step: Option[Int],
+          start: Option[Long],
+          end: Long,
+          step: Option[Long],
           numPartitions: Option[Int]): Relation = {
         val range = proto.Range.newBuilder()
         if (start.isDefined) {
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index f4b6d2ec302..e64d612c53e 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -32,7 +32,7 @@ import pyspark.sql.types
 from pyspark import cloudpickle
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.readwriter import DataFrameReader
-from pyspark.sql.connect.plan import SQL
+from pyspark.sql.connect.plan import SQL, Range
 from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
 
 from typing import Optional, Any, Union
@@ -145,6 +145,39 @@ class RemoteSparkSession(object):
     def sql(self, sql_string: str) -> "DataFrame":
         return DataFrame.withPlan(SQL(sql_string), self)
 
+    def range(
+        self,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        numPartitions: Optional[int] = None,
+    ) -> DataFrame:
+        """
+        Create a :class:`DataFrame` with a column named ``id`` of type Long,
+        containing elements in a range from ``start`` to ``end`` (exclusive) with
+        step value ``step``.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        start : int
+            the start value
+        end : int
+            the end value (exclusive)
+        step : int, optional
+            the incremental step (default: 1)
+        numPartitions : int, optional
+            the number of partitions of the DataFrame
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return DataFrame.withPlan(
+            Range(start=start, end=end, step=step, num_partitions=numPartitions), self
+        )
+
     def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
         req = pb2.Request()
         req.user_context.user_id = self._user_id
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 56c609d9576..71c971d9e91 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -698,3 +698,53 @@ class SQL(LogicalPlan):
            </li>
         </ul>
         """
+
+
+class Range(LogicalPlan):
+    def __init__(
+        self,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+    ) -> None:
+        super().__init__(None)
+        self._start = start
+        self._end = end
+        self._step = step
+        self._num_partitions = num_partitions
+
+    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
+        rel = proto.Relation()
+        rel.range.start = self._start
+        rel.range.end = self._end
+        if self._step is not None:
+            step_proto = rel.range.Step()
+            step_proto.step = self._step
+            rel.range.step.CopyFrom(step_proto)
+        if self._num_partitions is not None:
+            num_partitions_proto = rel.range.NumPartitions()
+            num_partitions_proto.num_partitions = self._num_partitions
+            rel.range.num_partitions.CopyFrom(num_partitions_proto)
+        return rel
+
+    def print(self, indent: int = 0) -> str:
+        return (
+            f"{' ' * indent}"
+            f"<Range start={self._start}, end={self._end}, "
+            f"step={self._step}, num_partitions={self._num_partitions}>"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>Range</b><br />
+              Start: {self._start} <br />
+              End: {self._end} <br />
+              Step: {self._step} <br />
+              NumPartitions: {self._num_partitions} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 5db36434e30..6741341326a 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
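One detail worth calling out in `Range.plan()` above: `step` and `num_partitions` are nested wrapper messages rather than plain scalar fields, so an unset value can be distinguished from an explicit 0 via `HasField`, which is exactly what the plan-only tests below assert. A small hand-built sketch (assuming the same `proto` alias for the generated bindings that `plan.py` uses):

```python
import pyspark.sql.connect.proto as proto  # the alias plan.py is assumed to use

rel = proto.Relation()
rel.range.start = 10
rel.range.end = 20
# Assigning through a nested message implicitly creates it, so HasField("step")
# becomes True only when a step was actually supplied.
rel.range.step.step = 3
rel.range.num_partitions.num_partitions = 4

assert rel.range.HasField("step")
assert rel.range.HasField("num_partitions")
```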
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 11d51f73476..e9a06f9c545 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -134,6 +134,23 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         self.assertIsNotNone(pd)
         self.assertEqual(10, len(pd.index))
 
+    def test_range(self):
+        self.assertTrue(
+            self.connect.range(start=0, end=10)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10).toPandas())
+        )
+        self.assertTrue(
+            self.connect.range(start=0, end=10, step=3)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10, step=3).toPandas())
+        )
+        self.assertTrue(
+            self.connect.range(start=0, end=10, step=3, numPartitions=2)
+            .toPandas()
+            .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas())
+        )
+
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
         tmpPath = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index a295c9612be..054d2b00088 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -144,6 +144,21 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         plan = df.alias("table_alias")._plan.to_proto(self.connect)
         self.assertEqual(plan.root.subquery_alias.alias, "table_alias")
 
+    def test_range(self):
+        plan = self.connect.range(start=10, end=20, step=3, num_partitions=4)._plan.to_proto(
+            self.connect
+        )
+        self.assertEqual(plan.root.range.start, 10)
+        self.assertEqual(plan.root.range.end, 20)
+        self.assertEqual(plan.root.range.step.step, 3)
+        self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
+
+        plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.range.start, 10)
+        self.assertEqual(plan.root.range.end, 20)
+        self.assertFalse(plan.root.range.HasField("step"))
+        self.assertFalse(plan.root.range.HasField("num_partitions"))
+
     def test_datasource_read(self):
         reader = DataFrameReader(self.connect)
         df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2")
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index d9bced3af11..9d0aa7f6884 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -15,14 +15,14 @@
 # limitations under the License.
 #
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import functools
 import unittest
 from pyspark.testing.sqlutils import have_pandas
 
 if have_pandas:
     from pyspark.sql.connect import DataFrame
-    from pyspark.sql.connect.plan import Read
+    from pyspark.sql.connect.plan import Read, Range
 from pyspark.testing.utils import search_jar
 
 connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
@@ -75,6 +75,18 @@ class PlanOnlyTestFixture(unittest.TestCase):
     def _udf_mock(cls, *args, **kwargs) -> str:
         return "internal_name"
 
+    @classmethod
+    def _session_range(
+        cls,
+        start: int,
+        end: int,
+        step: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+    ) -> "DataFrame":
+        return DataFrame.withPlan(
+            Range(start, end, step, num_partitions), cls.connect  # type: ignore
+        )
+
     @classmethod
     def setUpClass(cls: Any) -> None:
         cls.connect = MockRemoteSession()
@@ -82,8 +94,10 @@
         cls.connect.set_hook("register_udf", cls._udf_mock)
         cls.connect.set_hook("readTable", cls._read_table)
+        cls.connect.set_hook("range", cls._session_range)
 
     @classmethod
     def tearDownClass(cls: Any) -> None:
         cls.connect.drop_hook("register_udf")
         cls.connect.drop_hook("readTable")
+        cls.connect.drop_hook("range")
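The plan-only tests run without any server because `MockRemoteSession` resolves attribute access through the hooks registered in `setUpClass`, so `self.connect.range(...)` dispatches to `_session_range` above. Roughly, the pattern is as follows (a simplified sketch, not the actual implementation in `pyspark.testing.connectutils`):

```python
from typing import Any, Callable, Dict

class MockRemoteSession:
    """Simplified stand-in: unknown attributes fall through to named hooks."""

    def __init__(self) -> None:
        self.hooks: Dict[str, Callable] = {}

    def set_hook(self, name: str, hook: Callable) -> None:
        self.hooks[name] = hook

    def drop_hook(self, name: str) -> None:
        self.hooks.pop(name, None)

    def __getattr__(self, item: str) -> Any:
        # Only invoked when normal attribute lookup fails.
        if item not in self.hooks:
            raise LookupError(f"No hook found for {item}")
        return self.hooks[item]
```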
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org