This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bac7050cf0a Revert "[SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter"
bac7050cf0a is described below

commit bac7050cf0ad18608e921f46e40152d341d53fb8
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Jul 14 09:31:17 2023 +0900

    Revert "[SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string 
formatter"
    
    This reverts commit 9aa42a970c4bd8e54603b1795a0f449bd556b11b.
---
 python/pyspark/pandas/sql_formatter.py      |  7 +--
 python/pyspark/sql/connect/session.py       | 35 ++++---------
 python/pyspark/sql/connect/sql_formatter.py | 78 -----------------------------
 python/pyspark/sql/sql_formatter.py         |  5 +-
 python/pyspark/sql/utils.py                 |  8 ---
 5 files changed, 16 insertions(+), 117 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 7501e19c038..8593703bd94 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -264,9 +264,10 @@ class PandasSQLStringFormatter(string.Formatter):
             val._to_spark().createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            from pyspark.sql.utils import get_lit_sql_str
-
-            return get_lit_sql_str(val)
+            # This is matched to behavior from JVM implementation.
+            # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
+            # sql/catalyst/expressions/literals.scala`
+            return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
         else:
             return val
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 13868263174..ea88d60d760 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,31 +489,13 @@ class SparkSession:
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
-    def sql(
-        self,
-        sqlQuery: str,
-        args: Optional[Union[Dict[str, Any], List]] = None,
-        **kwargs: Any,
-    ) -> "DataFrame":
-
-        if len(kwargs) > 0:
-            from pyspark.sql.connect.sql_formatter import SQLStringFormatter
-
-            formatter = SQLStringFormatter(self)
-            sqlQuery = formatter.format(sqlQuery, **kwargs)
-
-        try:
-            cmd = SQL(sqlQuery, args)
-            data, properties = self.client.execute_command(cmd.command(self._client))
-            if "sql_command_result" in properties:
-                return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
-            else:
-                return DataFrame.withPlan(SQL(sqlQuery, args), self)
-        finally:
-            if len(kwargs) > 0:
-                # TODO: should drop temp views after SPARK-44406 get resolved
-                # formatter.clear()
-                pass
+    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
+        cmd = SQL(sqlQuery, args)
+        data, properties = self.client.execute_command(cmd.command(self._client))
+        if "sql_command_result" in properties:
+            return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
+        else:
+            return DataFrame.withPlan(SQL(sqlQuery, args), self)
 
     sql.__doc__ = PySparkSession.sql.__doc__
 
@@ -826,6 +808,9 @@ def _test() -> None:
     # RDD API is not supported in Spark Connect.
     del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__
 
+    # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
+    del pyspark.sql.connect.session.SparkSession.sql.__doc__
+
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.connect.session,
         globs=globs,
diff --git a/python/pyspark/sql/connect/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py
deleted file mode 100644
index ab90a1bb847..00000000000
--- a/python/pyspark/sql/connect/sql_formatter.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import string
-import typing
-from typing import Any, Optional, List, Tuple, Sequence, Mapping
-import uuid
-
-if typing.TYPE_CHECKING:
-    from pyspark.sql.connect.session import SparkSession
-    from pyspark.sql.connect.dataframe import DataFrame
-
-
-class SQLStringFormatter(string.Formatter):
-    """
-    A standard ``string.Formatter`` in Python that can understand PySpark instances
-    with basic Python objects. This object has to be clear after the use for single SQL
-    query; cannot be reused across multiple SQL queries without cleaning.
-    """
-
-    def __init__(self, session: "SparkSession") -> None:
-        self._session: "SparkSession" = session
-        self._temp_views: List[Tuple[DataFrame, str]] = []
-
-    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
-        obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
-        return self._convert_value(obj, field_name), first
-
-    def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
-        """
-        Converts the given value into a SQL string.
-        """
-        from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.connect.column import Column
-        from pyspark.sql.connect.expressions import ColumnReference
-
-        if isinstance(val, Column):
-            expr = val._expr
-            if isinstance(expr, ColumnReference):
-                return expr._unparsed_identifier
-            else:
-                raise ValueError(
-                    "%s in %s should be a plain column reference such as 
`df.col` "
-                    "or `col('column')`" % (val, field_name)
-                )
-        elif isinstance(val, DataFrame):
-            for df, n in self._temp_views:
-                if df is val:
-                    return n
-            df_name = "_pyspark_connect_%s" % str(uuid.uuid4()).replace("-", "")
-            self._temp_views.append((val, df_name))
-            val.createOrReplaceTempView(df_name)
-            return df_name
-        elif isinstance(val, str):
-            from pyspark.sql.utils import get_lit_sql_str
-
-            return get_lit_sql_str(val)
-        else:
-            return val
-
-    def clear(self) -> None:
-        for _, n in self._temp_views:
-            self._session.catalog.dropTempView(n)
-        self._temp_views = []
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py
index fbaa6c46a26..5e79b9ff5ea 100644
--- a/python/pyspark/sql/sql_formatter.py
+++ b/python/pyspark/sql/sql_formatter.py
@@ -24,6 +24,7 @@ from py4j.java_gateway import is_instance_of
 
 if typing.TYPE_CHECKING:
     from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql.functions import lit
 
 
 class SQLStringFormatter(string.Formatter):
@@ -73,9 +74,7 @@ class SQLStringFormatter(string.Formatter):
             val.createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            from pyspark.sql.utils import get_lit_sql_str
-
-            return get_lit_sql_str(val)
+            return lit(val)._jc.expr().sql()  # for escaped characters.
         else:
             return val
 
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f2874ccb10e..608ed7e9ac9 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -294,11 +294,3 @@ def get_window_class() -> Type["Window"]:
         return ConnectWindow  # type: ignore[return-value]
     else:
         return PySparkWindow
-
-
-def get_lit_sql_str(val: str) -> str:
-    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
-    # This is matched to behavior from JVM implementation.
-    # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
-    # sql/catalyst/expressions/literals.scala`
-    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
