This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 26f4953 [SPARK-37516][PYTHON][SQL] Uses Python's standard string formatter for SQL API in PySpark 26f4953 is described below commit 26f495370fb45071f52cde6fff199d7f4b674bc7 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Wed Dec 8 13:57:35 2021 +0900 [SPARK-37516][PYTHON][SQL] Uses Python's standard string formatter for SQL API in PySpark ### What changes were proposed in this pull request? This PR proposes to use [Python's standard string formatter](https://docs.python.org/3/library/string.html#custom-string-formatting) in `SparkSession.sql`; see also https://github.com/apache/spark/pull/34677. ### Why are the changes needed? To improve usability in PySpark. It works together with Python's standard string formatter. ### Does this PR introduce _any_ user-facing change? By default, there is no user-facing change. If `kwargs` is specified, yes. 1. Attribute access on a `DataFrame` (standard Python support): ```python mydf = spark.range(10) spark.sql("SELECT {tbl.id}, {tbl[id]} FROM {tbl}", tbl=mydf) ``` 2. Understanding `DataFrame`: ```python mydf = spark.range(10) spark.sql("SELECT * FROM {tbl}", tbl=mydf) ``` 3. Understanding `Column` (explicit column references only): ```python mydf = spark.range(10) spark.sql("SELECT {c} FROM {tbl}", c=col("id"), tbl=mydf) ``` 4. Leveraging other Python string formatting: ```python mydf = spark.range(10) spark.sql( "SELECT {col} FROM {mydf} WHERE id IN {x}", col=mydf.id, mydf=mydf, x=tuple(range(4))) ``` ### How was this patch tested? Doctests were added. Closes #34774 from HyukjinKwon/SPARK-37516. 
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/pandas/sql_formatter.py | 10 ++-- python/pyspark/pandas/tests/test_sql.py | 4 -- python/pyspark/sql/session.py | 90 +++++++++++++++++++++++++++++--- python/pyspark/sql/sql_formatter.py | 84 +++++++++++++++++++++++++++++ python/pyspark/sql/tests/test_session.py | 10 +++- 5 files changed, 182 insertions(+), 16 deletions(-) diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py index 685ee25..4ade2b9 100644 --- a/python/pyspark/pandas/sql_formatter.py +++ b/python/pyspark/pandas/sql_formatter.py @@ -163,7 +163,7 @@ def sql( return sql_processor.sql(query, index_col=index_col, **kwargs) session = default_session() - formatter = SQLStringFormatter(session) + formatter = PandasSQLStringFormatter(session) try: sdf = session.sql(formatter.format(query, **kwargs)) finally: @@ -178,7 +178,7 @@ def sql( ) -class SQLStringFormatter(string.Formatter): +class PandasSQLStringFormatter(string.Formatter): """ A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances with basic Python objects. 
This object has to be clear after the use for single SQL @@ -191,7 +191,7 @@ class SQLStringFormatter(string.Formatter): self._ref_sers: List[Tuple[Series, str]] = [] def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str: - ret = super(SQLStringFormatter, self).vformat(format_string, args, kwargs) + ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs) for ref, n in self._ref_sers: if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views): @@ -200,7 +200,7 @@ class SQLStringFormatter(string.Formatter): return ret def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: - obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs) return self._convert_value(obj, field_name), first def _convert_value(self, val: Any, name: str) -> Optional[str]: @@ -256,7 +256,7 @@ def _test() -> None: globs["ps"] = pyspark.pandas spark = ( SparkSession.builder.master("local[4]") - .appName("pyspark.pandas.sql_processor tests") + .appName("pyspark.pandas.sql_formatter tests") .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py index ca0dd99..5a5d6d4 100644 --- a/python/pyspark/pandas/tests/test_sql.py +++ b/python/pyspark/pandas/tests/test_sql.py @@ -26,10 +26,6 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(KeyError, "variable_foo"): ps.sql("select * from {variable_foo}") - def test_error_unsupported_type(self): - with self.assertRaisesRegex(KeyError, "some_dict"): - ps.sql("select * from {some_dict}") - def test_error_bad_sql(self): with self.assertRaises(ParseException): ps.sql("this is not valid sql") diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 586af62..6ff63bc 100644 --- 
a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -44,6 +44,7 @@ from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.conversion import SparkConversionMixin from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.sql_formatter import SQLStringFormatter from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import ( AtomicType, @@ -924,23 +925,100 @@ class SparkSession(SparkConversionMixin): df._schema = struct return df - def sql(self, sqlQuery: str) -> DataFrame: + def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: """Returns a :class:`DataFrame` representing the result of the given query. + When ``kwargs`` is specified, this method formats the given string by using the Python + standard formatter. .. versionadded:: 2.0.0 + Parameters + ---------- + sqlQuery : str + SQL query string. + kwargs : dict + Other variables that the user wants to set that can be referenced in the query + + .. versionchanged:: 3.3.0 + Added optional argument ``kwargs`` to specify the mapping of variables in the query. + This feature is experimental and unstable. + Returns ------- :class:`DataFrame` Examples -------- - >>> df.createOrReplaceTempView("table1") - >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> df2.collect() - [Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')] + Executing a SQL query. + + >>> spark.sql("SELECT * FROM range(10) where id > 7").show() + +---+ + | id| + +---+ + | 8| + | 9| + +---+ + + Executing a SQL query with variables as Python formatter standard. + + >>> spark.sql( + ... "SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9 + ... ).show() + +---+ + | id| + +---+ + | 8| + +---+ + + >>> mydf = spark.range(10) + >>> spark.sql( + ... "SELECT {col} FROM {mydf} WHERE id IN {x}", + ... 
col=mydf.id, mydf=mydf, x=tuple(range(4))).show() + +---+ + | id| + +---+ + | 0| + | 1| + | 2| + | 3| + +---+ + + >>> spark.sql(''' + ... SELECT m1.a, m2.b + ... FROM {table1} m1 INNER JOIN {table2} m2 + ... ON m1.key = m2.key + ... ORDER BY m1.a, m2.b''', + ... table1=spark.createDataFrame([(1, "a"), (2, "b")], ["a", "key"]), + ... table2=spark.createDataFrame([(3, "a"), (4, "b"), (5, "b")], ["b", "key"])).show() + +---+---+ + | a| b| + +---+---+ + | 1| 3| + | 2| 4| + | 2| 5| + +---+---+ + + Also, it is possible to query using class:`Column` from :class:`DataFrame`. + + >>> mydf = spark.createDataFrame([(1, 4), (2, 4), (3, 6)], ["A", "B"]) + >>> spark.sql("SELECT {df.A}, {df[B]} FROM {df}", df=mydf).show() + +---+---+ + | A| B| + +---+---+ + | 1| 4| + | 2| 4| + | 3| 6| + +---+---+ """ - return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) + + formatter = SQLStringFormatter(self) + if len(kwargs) > 0: + sqlQuery = formatter.format(sqlQuery, **kwargs) + try: + return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) + finally: + if len(kwargs) > 0: + formatter.clear() def table(self, tableName: str) -> DataFrame: """Returns the specified table as a :class:`DataFrame`. diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py new file mode 100644 index 0000000..8528dd3 --- /dev/null +++ b/python/pyspark/sql/sql_formatter.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import string +import typing +from typing import Any, Optional, List, Tuple, Sequence, Mapping +import uuid + +from py4j.java_gateway import is_instance_of + +if typing.TYPE_CHECKING: + from pyspark.sql import SparkSession, DataFrame +from pyspark.sql.functions import lit + + +class SQLStringFormatter(string.Formatter): + """ + A standard ``string.Formatter`` in Python that can understand PySpark instances + with basic Python objects. This object has to be clear after the use for single SQL + query; cannot be reused across multiple SQL queries without cleaning. + """ + + def __init__(self, session: "SparkSession") -> None: + self._session: "SparkSession" = session + self._temp_views: List[Tuple[DataFrame, str]] = [] + + def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: + obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + return self._convert_value(obj, field_name), first + + def _convert_value(self, val: Any, field_name: str) -> Optional[str]: + """ + Converts the given value into a SQL string. 
+ """ + from pyspark import SparkContext + from pyspark.sql import Column, DataFrame + + if isinstance(val, Column): + assert SparkContext._gateway is not None # type: ignore[attr-defined] + + gw = SparkContext._gateway # type: ignore[attr-defined] + jexpr = val._jc.expr() + if is_instance_of( + gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute" + ) or is_instance_of( + gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference" + ): + return jexpr.sql() + else: + raise ValueError( + "%s in %s should be a plain column reference such as `df.col` " + "or `col('column')`" % (val, field_name) + ) + elif isinstance(val, DataFrame): + for df, n in self._temp_views: + if df is val: + return n + df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "") + self._temp_views.append((val, df_name)) + val.createOrReplaceTempView(df_name) + return df_name + elif isinstance(val, str): + return lit(val)._jc.expr().sql() # for escaped characters. + else: + return val + + def clear(self) -> None: + for _, n in self._temp_views: + self._session.catalog.dropTempView(n) + self._temp_views = [] diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 84fa23d..1262e52 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -20,6 +20,7 @@ import unittest from pyspark import SparkConf, SparkContext from pyspark.sql import SparkSession, SQLContext, Row +from pyspark.sql.functions import col from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.testing.utils import PySparkTestCase @@ -93,7 +94,7 @@ class SparkSessionTests3(unittest.TestCase): active = SparkSession.getActiveSession() self.assertEqual(active, None) - def test_SparkSession(self): + def test_spark_session(self): spark = SparkSession.builder.master("local").config("some-config", "v2").getOrCreate() try: self.assertEqual(spark.conf.get("some-config"), "v2") @@ -105,6 +106,13 @@ class 
SparkSessionTests3(unittest.TestCase): spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") self.assertEqual(spark.table("table1").columns, ["name", "age"]) self.assertEqual(spark.range(3).count(), 3) + + # SPARK-37516: Only plain column references work as variable in SQL. + self.assertEqual( + spark.sql("select {c} from range(1)", c=col("id")).first(), spark.range(1).first() + ) + with self.assertRaisesRegex(ValueError, "Column"): + spark.sql("select {c} from range(10)", c=col("id") + 1) finally: spark.sql("DROP DATABASE test_db CASCADE") spark.stop() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org