This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9aa42a970c4 [SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter
9aa42a970c4 is described below

commit 9aa42a970c4bd8e54603b1795a0f449bd556b11b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jul 13 17:58:00 2023 +0800

    [SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter
    
    ### What changes were proposed in this pull request?
    Implement SparkSession.sql's string formatter
    
    ### Why are the changes needed?
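    For parity with the regular (non-Connect) PySpark `SparkSession.sql`, which already supports string formatting via keyword arguments.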
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    before:
    ```
    In [1]: spark.createDataFrame([("Alice", 6), ("Bob", 7), ("John", 10)], ['name', 'age']).createOrReplaceTempView("person")
    
    In [2]: spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Cell In[2], line 1
    ----> 1 spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()
    
    TypeError: sql() got an unexpected keyword argument 'age'
    ```
    
    after:
    ```
    In [1]: spark.createDataFrame([("Alice", 6), ("Bob", 7), ("John", 10)], ['name', 'age']).createOrReplaceTempView("person")
    
    In [2]: spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()
    +-----+---+
    | name|age|
    +-----+---+
    |Alice|  6|
    |  Bob|  7|
    +-----+---+
    ```
    
    ### How was this patch tested?
    enabled doc test
    
    Closes #41980 from zhengruifeng/py_connect_sql_formatter.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/sql_formatter.py            |  7 ++---
 python/pyspark/sql/connect/session.py             | 35 ++++++++++++++++-------
 python/pyspark/sql/{ => connect}/sql_formatter.py | 30 ++++++++-----------
 python/pyspark/sql/sql_formatter.py               |  5 ++--
 python/pyspark/sql/utils.py                       |  8 ++++++
 5 files changed, 51 insertions(+), 34 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 8593703bd94..7501e19c038 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -264,10 +264,9 @@ class PandasSQLStringFormatter(string.Formatter):
             val._to_spark().createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            # This is matched to behavior from JVM implementation.
-            # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
-            # sql/catalyst/expressions/literals.scala`
-            return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
+            from pyspark.sql.utils import get_lit_sql_str
+
+            return get_lit_sql_str(val)
         else:
             return val
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index ea88d60d760..13868263174 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,13 +489,31 @@ class SparkSession:
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
-    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
-        cmd = SQL(sqlQuery, args)
-        data, properties = self.client.execute_command(cmd.command(self._client))
-        if "sql_command_result" in properties:
-            return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
-        else:
-            return DataFrame.withPlan(SQL(sqlQuery, args), self)
+    def sql(
+        self,
+        sqlQuery: str,
+        args: Optional[Union[Dict[str, Any], List]] = None,
+        **kwargs: Any,
+    ) -> "DataFrame":
+
+        if len(kwargs) > 0:
+            from pyspark.sql.connect.sql_formatter import SQLStringFormatter
+
+            formatter = SQLStringFormatter(self)
+            sqlQuery = formatter.format(sqlQuery, **kwargs)
+
+        try:
+            cmd = SQL(sqlQuery, args)
+            data, properties = self.client.execute_command(cmd.command(self._client))
+            if "sql_command_result" in properties:
+                return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
+            else:
+                return DataFrame.withPlan(SQL(sqlQuery, args), self)
+        finally:
+            if len(kwargs) > 0:
+                # TODO: should drop temp views after SPARK-44406 get resolved
+                # formatter.clear()
+                pass
 
     sql.__doc__ = PySparkSession.sql.__doc__
 
@@ -808,9 +826,6 @@ def _test() -> None:
     # RDD API is not supported in Spark Connect.
     del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__
 
-    # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
-    del pyspark.sql.connect.session.SparkSession.sql.__doc__
-
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.connect.session,
         globs=globs,
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py
similarity index 76%
copy from python/pyspark/sql/sql_formatter.py
copy to python/pyspark/sql/connect/sql_formatter.py
index 5e79b9ff5ea..ab90a1bb847 100644
--- a/python/pyspark/sql/sql_formatter.py
+++ b/python/pyspark/sql/connect/sql_formatter.py
@@ -20,11 +20,9 @@ import typing
 from typing import Any, Optional, List, Tuple, Sequence, Mapping
 import uuid
 
-from py4j.java_gateway import is_instance_of
-
 if typing.TYPE_CHECKING:
-    from pyspark.sql import SparkSession, DataFrame
-from pyspark.sql.functions import lit
+    from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.connect.dataframe import DataFrame
 
 
 class SQLStringFormatter(string.Formatter):
@@ -46,20 +44,14 @@ class SQLStringFormatter(string.Formatter):
         """
         Converts the given value into a SQL string.
         """
-        from pyspark import SparkContext
-        from pyspark.sql import Column, DataFrame
+        from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.connect.column import Column
+        from pyspark.sql.connect.expressions import ColumnReference
 
         if isinstance(val, Column):
-            assert SparkContext._gateway is not None
-
-            gw = SparkContext._gateway
-            jexpr = val._jc.expr()
-            if is_instance_of(
-                gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
-            ) or is_instance_of(
-                gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference"
-            ):
-                return jexpr.sql()
+            expr = val._expr
+            if isinstance(expr, ColumnReference):
+                return expr._unparsed_identifier
             else:
                 raise ValueError(
                     "%s in %s should be a plain column reference such as `df.col` "
@@ -69,12 +61,14 @@ class SQLStringFormatter(string.Formatter):
             for df, n in self._temp_views:
                 if df is val:
                     return n
-            df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "")
+            df_name = "_pyspark_connect_%s" % str(uuid.uuid4()).replace("-", "")
             self._temp_views.append((val, df_name))
             val.createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            return lit(val)._jc.expr().sql()  # for escaped characters.
+            from pyspark.sql.utils import get_lit_sql_str
+
+            return get_lit_sql_str(val)
         else:
             return val
 
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py
index 5e79b9ff5ea..fbaa6c46a26 100644
--- a/python/pyspark/sql/sql_formatter.py
+++ b/python/pyspark/sql/sql_formatter.py
@@ -24,7 +24,6 @@ from py4j.java_gateway import is_instance_of
 
 if typing.TYPE_CHECKING:
     from pyspark.sql import SparkSession, DataFrame
-from pyspark.sql.functions import lit
 
 
 class SQLStringFormatter(string.Formatter):
@@ -74,7 +73,9 @@ class SQLStringFormatter(string.Formatter):
             val.createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            return lit(val)._jc.expr().sql()  # for escaped characters.
+            from pyspark.sql.utils import get_lit_sql_str
+
+            return get_lit_sql_str(val)
         else:
             return val
 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 608ed7e9ac9..f2874ccb10e 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -294,3 +294,11 @@ def get_window_class() -> Type["Window"]:
         return ConnectWindow  # type: ignore[return-value]
     else:
         return PySparkWindow
+
+
+def get_lit_sql_str(val: str) -> str:
+    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
+    # This is matched to behavior from JVM implementation.
+    # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
+    # sql/catalyst/expressions/literals.scala`
+    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to