This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e98987220ae [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()` e98987220ae is described below commit e98987220ae191ecc10944026fee9c57ddf478c1 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Mon Jun 26 19:42:17 2023 +0300 [SPARK-44189][CONNECT][PYTHON] Support positional parameters by `sql()` ### What changes were proposed in this pull request? In the PR, I propose to extend the `sql()` method of Python connect client, and support positional parameters as list of Python objects that can be converted to literal expressions. ```python def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> DataFrame: ``` where - **args** is a dictionary of parameter names to Python objects or a list of Python objects that can be converted to SQL literal expressions. See the [link](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) regarding the supported value types in PySpark. For example: _1, "Steven", datetime.date(2023, 4, 2)_. The same as in Scala/Java API, a value can be also a `Column` of literal expression, in that case it is taken as is. For example: ```python >>> connect.sql("SELECT * FROM {df} WHERE {df[B]} > ? and ? < {df[A]}", [5, 2], df=mydf).show() +---+---+ | A| B| +---+---+ | 3| 6| +---+---+ ``` ### Why are the changes needed? To achieve feature parity with the PySpark API. ### Does this PR introduce _any_ user-facing change? No, the PR just extends the existing API. ### How was this patch tested? By running new test: ``` $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_pos_args' ``` and the renamed test: ``` $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_named_args' ``` Closes #41739 from MaxGekk/positional-params-python-connect. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- python/pyspark/sql/connect/plan.py | 36 ++++++++++++++++------ python/pyspark/sql/connect/session.py | 2 +- .../sql/tests/connect/test_connect_basic.py | 7 ++++- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 406f65080d1..fabab98d9b2 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1019,12 +1019,15 @@ class SubqueryAlias(LogicalPlan): class SQL(LogicalPlan): - def __init__(self, query: str, args: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, query: str, args: Optional[Union[Dict[str, Any], List]] = None) -> None: super().__init__(None) if args is not None: - for k, v in args.items(): - assert isinstance(k, str) + if isinstance(args, Dict): + for k, v in args.items(): + assert isinstance(k, str) + else: + assert isinstance(args, List) self._query = query self._args = args @@ -1034,8 +1037,16 @@ class SQL(LogicalPlan): plan.sql.query = self._query if self._args is not None and len(self._args) > 0: - for k, v in self._args.items(): - plan.sql.args[k].CopyFrom(LiteralExpression._from_value(v).to_plan(session).literal) + if isinstance(self._args, Dict): + for k, v in self._args.items(): + plan.sql.args[k].CopyFrom( + LiteralExpression._from_value(v).to_plan(session).literal + ) + else: + for v in self._args: + plan.sql.pos_args.append( + LiteralExpression._from_value(v).to_plan(session).literal + ) return plan @@ -1043,10 +1054,17 @@ class SQL(LogicalPlan): cmd = proto.Command() cmd.sql_command.sql = self._query if self._args is not None and len(self._args) > 0: - for k, v in self._args.items(): - cmd.sql_command.args[k].CopyFrom( - LiteralExpression._from_value(v).to_plan(session).literal - ) + if isinstance(self._args, Dict): + for k, v in self._args.items(): + cmd.sql_command.args[k].CopyFrom( + 
LiteralExpression._from_value(v).to_plan(session).literal + ) + else: + for v in self._args: + cmd.sql_command.pos_args.append( + LiteralExpression._from_value(v).to_plan(session).literal + ) + return cmd diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 365829ff7bc..356dacd8e18 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -489,7 +489,7 @@ class SparkSession: createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__ - def sql(self, sqlQuery: str, args: Optional[Dict[str, Any]] = None) -> "DataFrame": + def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame": cmd = SQL(sqlQuery, args) data, properties = self.client.execute_command(cmd.command(self._client)) if "sql_command_result" in properties: diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 89384b24e45..268011ef1e4 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -1223,11 +1223,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): pdf = self.connect.sql("SELECT 1").toPandas() self.assertEqual(1, len(pdf.index)) - def test_sql_with_args(self): + def test_sql_with_named_args(self): df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7}) df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7}) self.assert_eq(df.toPandas(), df2.toPandas()) + def test_sql_with_pos_args(self): + df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) + df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) + self.assert_eq(df.toPandas(), df2.toPandas()) + def test_head(self): # SPARK-41002: test `head` API in Python Client df = self.connect.read.table(self.tbl_name) --------------------------------------------------------------------- 
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org