This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 86ae0d2bc198 [SPARK-47274][PYTHON][SQL] Provide more useful context for PySpark DataFrame API errors
86ae0d2bc198 is described below

commit 86ae0d2bc19832f5bf5d872491cdede800427691
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Thu Apr 11 09:41:31 2024 +0800

    [SPARK-47274][PYTHON][SQL] Provide more useful context for PySpark DataFrame API errors
    
    ### What changes were proposed in this pull request?
    
    This PR enhances the error messages generated by PySpark's DataFrame API by adding detailed context about the location in the user's PySpark code where the error occurred.
    
    It adds the PySpark user call-site information to the `DataFrameQueryContext` introduced in https://github.com/apache/spark/pull/43334, giving PySpark users the same level of detailed error context for DataFrame APIs and improving usability and debugging efficiency.
    
    This PR also introduces `QueryContext.pysparkCallSite` and `QueryContext.pysparkFragment` so that the PySpark-specific information can be retrieved from the query context easily.
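    
    As a quick sketch of how these new accessors could be used (assuming an existing `spark` session, as in the other snippets in this description):
    
    ```python
    from pyspark.errors import ArithmeticException, QueryContextType

    spark.conf.set("spark.sql.ansi.enabled", True)
    df = spark.range(10)

    try:
        df.withColumn("div_zero", df.id / 0).collect()
    except ArithmeticException as e:
        for ctx in e.getQueryContext():
            if ctx.contextType() == QueryContextType.DataFrame:
                # Both accessors return None for non-DataFrame contexts.
                print(ctx.pysparkFragment())  # e.g. "divide"
                print(ctx.pysparkCallSite())  # e.g. "/path/to/user_code.py:7"
    ```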
    
    This PR also extends `check_error` so that it can verify the query context when one exists.
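    
    For example, inside a test case the query context can be asserted roughly like this (a sketch mirroring one of the added UTs, with `ArithmeticException` and `QueryContextType` imported from `pyspark.errors`):
    
    ```python
    with self.assertRaises(ArithmeticException) as pe:
        df.withColumn("div_zero", df.id / 0).collect()
    self.check_error(
        exception=pe.exception,
        error_class="DIVIDE_BY_ZERO",
        message_parameters={"config": '"spark.sql.ansi.enabled"'},
        query_context_type=QueryContextType.DataFrame,
        pyspark_fragment="divide",
    )
    ```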
    
    ### Why are the changes needed?
    
    To improve debuggability. Errors originating from PySpark operations can be difficult to debug when the error messages carry limited context. While the JVM side has been improved to offer detailed error context, PySpark errors often lack this level of detail.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API changes, but error messages will include a reference to the exact line of user code that triggered the error, in addition to the existing descriptive error message.
    
    For example, consider the following PySpark code snippet that triggers a `DIVIDE_BY_ZERO` error:
    
    ```python
    1  spark.conf.set("spark.sql.ansi.enabled", True)
    2
    3  df = spark.range(10)
    4  df.select(df.id / 0).show()
    ```
    
    **Before:**
    ```
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
    == DataFrame ==
    "divide" was called from
    java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    ```
    
    **After:**
    ```
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
    == DataFrame ==
    "divide" was called from
    /.../spark/python/test_pyspark_error.py:4
    ```
    
    Now the error message points to the exact problematic site in the user's code, including the file name and line number.
    
    ## Points to the actual problem site instead of the site where the action was called
    
    Even when the action is called after multiple mixed transform operations, the exact problematic site is reported to the user:
    
    **In:**
    
    ```python
      1 spark.conf.set("spark.sql.ansi.enabled", True)
      2 df = spark.range(10)
      3
      4 df1 = df.withColumn("div_ten", df.id / 10)
      5 df2 = df1.withColumn("plus_four", df.id + 4)
      6
      7 # This is the problematic divide operation that causes DIVIDE_BY_ZERO.
      8 df3 = df2.withColumn("div_zero", df.id / 0)
      9 df4 = df3.withColumn("minus_five", df.id / 5)
     10
     11 df4.collect()
    ```
    
    **Out:**
    
    ```
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
    == DataFrame ==
    "divide" was called from
    /.../spark/python/test_pyspark_error.py:8
    ```
    
    ### How was this patch tested?
    
    Added UTs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45377 from itholic/error_context_for_dataframe_api.
    
    Authored-by: Haejoon Lee <haejoon....@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/errors/exceptions/captured.py       |   8 +
 python/pyspark/sql/column.py                       |  37 +-
 .../sql/tests/connect/test_parity_dataframe.py     |   4 +
 python/pyspark/sql/tests/test_dataframe.py         | 485 +++++++++++++++++++++
 python/pyspark/testing/utils.py                    |  30 ++
 .../apache/spark/sql/catalyst/parser/parsers.scala |   2 +-
 .../spark/sql/catalyst/trees/QueryContexts.scala   |  18 +-
 .../apache/spark/sql/catalyst/trees/origin.scala   |   5 +-
 .../main/scala/org/apache/spark/sql/Column.scala   |  23 +
 .../main/scala/org/apache/spark/sql/package.scala  |  73 +++-
 10 files changed, 669 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index e5ec257fb32e..2a30eba3fb22 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -409,5 +409,13 @@ class QueryContext(BaseQueryContext):
     def callSite(self) -> str:
         return str(self._q.callSite())
 
+    def pysparkFragment(self) -> Optional[str]:  # type: ignore[return]
+        if self.contextType() == QueryContextType.DataFrame:
+            return str(self._q.pysparkFragment())
+
+    def pysparkCallSite(self) -> Optional[str]:  # type: ignore[return]
+        if self.contextType() == QueryContextType.DataFrame:
+            return str(self._q.pysparkCallSite())
+
     def summary(self) -> str:
         return str(self._q.summary())
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 31c1013742a0..fb266b03c2ff 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,6 +18,7 @@
 import sys
 import json
 import warnings
+import inspect
 from typing import (
     cast,
     overload,
@@ -174,16 +175,50 @@ def _bin_op(
     ["Column", Union["Column", "LiteralType", "DecimalLiteral", 
"DateTimeLiteral"]], "Column"
 ]:
     """Create a method for given binary operator"""
+    binary_operator_map = {
+        "plus": "+",
+        "minus": "-",
+        "divide": "/",
+        "multiply": "*",
+        "mod": "%",
+        "equalTo": "=",
+        "lt": "<",
+        "leq": "<=",
+        "geq": ">=",
+        "gt": ">",
+        "eqNullSafe": "<=>",
+        "bitwiseOR": "|",
+        "bitwiseAND": "&",
+        "bitwiseXOR": "^",
+        # Just following JVM rule even if the names of source and target are 
the same.
+        "and": "and",
+        "or": "or",
+    }
 
     def _(
         self: "Column",
         other: Union["Column", "LiteralType", "DecimalLiteral", 
"DateTimeLiteral"],
     ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, name)(jc)
+        if name in binary_operator_map:
+            from pyspark.sql import SparkSession
+
+            spark = SparkSession._getActiveSessionOrCreate()
+            stack = list(reversed(inspect.stack()))
+            depth = int(
+                spark.conf.get("spark.sql.stackTracesInDataFrameContext")  # 
type: ignore[arg-type]
+            )
+            selected_frames = stack[:depth]
+            call_sites = [f"{frame.filename}:{frame.lineno}" for frame in 
selected_frames]
+            call_site_str = "\n".join(call_sites)
+
+            njc = getattr(self._jc, "fn")(binary_operator_map[name], jc, name, 
call_site_str)
+        else:
+            njc = getattr(self._jc, name)(jc)
         return Column(njc)
 
     _.__doc__ = doc
+    _.__name__ = name
     return _
 
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 343f485553a9..6210d4ec72fe 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -30,6 +30,10 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_toDF_with_schema_string(self):
         super().test_toDF_with_schema_string()
 
+    @unittest.skip("Spark Connect does not support DataFrameQueryContext 
currently.")
+    def test_dataframe_error_context(self):
+        super().test_dataframe_error_context()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 1eccb40e709c..3f6a8eece5b0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,9 @@ from pyspark.errors import (
     AnalysisException,
     IllegalArgumentException,
     PySparkTypeError,
+    ArithmeticException,
+    QueryContextType,
+    NumberFormatException,
 )
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -832,6 +835,488 @@ class DataFrameTestsMixin:
         self.assertEqual(df.schema, schema)
         self.assertEqual(df.collect(), data)
 
+    def test_dataframe_error_context(self):
+        # SPARK-47274: Add more useful contexts for PySpark DataFrame API 
errors.
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
+            df = self.spark.range(10)
+
+            # DataFrameQueryContext with pysparkLoggingInfo - divide
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("div_zero", df.id / 0).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - plus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("plus_invalid_type", df.id + "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - minus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("minus_invalid_type", df.id - "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - multiply
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_invalid_type", df.id * 
"string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - mod
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("mod_invalid_type", df.id % "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="mod",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - equalTo
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("equalTo_invalid_type", df.id == 
"string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="equalTo",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - lt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("lt_invalid_type", df.id < "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="lt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - leq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("leq_invalid_type", df.id <= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="leq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - geq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("geq_invalid_type", df.id >= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="geq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - gt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("gt_invalid_type", df.id > "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="gt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("eqNullSafe_invalid_type", 
df.id.eqNullSafe("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="eqNullSafe",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - and
+            with self.assertRaises(AnalysisException) as pe:
+                df.withColumn("and_invalid_type", df.id & "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
+                message_parameters={
+                    "inputType": '"BOOLEAN"',
+                    "actualDataType": '"BIGINT"',
+                    "sqlExpr": '"(id AND string)"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="and",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - or
+            with self.assertRaises(AnalysisException) as pe:
+                df.withColumn("or_invalid_type", df.id | "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
+                message_parameters={
+                    "inputType": '"BOOLEAN"',
+                    "actualDataType": '"BIGINT"',
+                    "sqlExpr": '"(id OR string)"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="or",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseOR_invalid_type", 
df.id.bitwiseOR("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseOR",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseAND_invalid_type", 
df.id.bitwiseAND("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseAND",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseXOR_invalid_type", 
df.id.bitwiseXOR("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseXOR",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`divide` is problematic)
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_zero", df.id / 0
+                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", 
df.id - 10).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` 
is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_string", df.id + "string").withColumn(
+                    "minus_ten", df.id - 10
+                ).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` 
is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_ten", df.id + 10).withColumn(
+                    "minus_string", df.id - "string"
+                ).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`multiply` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_string", df.id * "string").withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", 
df.id - 10).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # Multiple expressions in df.select (`divide` is problematic)
+            with self.assertRaises(ArithmeticException) as pe:
+                df.select(df.id - 10, df.id + 4, df.id / 0, df.id * 
5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # Multiple expressions in df.select (`plus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + "string", df.id / 10, df.id * 
5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions in df.select (`minus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - "string", df.id + 4, df.id / 10, df.id * 
5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions in df.select (`multiply` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + 4, df.id / 10, df.id * 
"string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # Multiple expressions with pre-declared expressions (`divide` is 
problematic)
+            a = df.id / 10
+            b = df.id / 0
+            with self.assertRaises(ArithmeticException) as pe:
+                df.select(a, df.id + 4, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # Multiple expressions with pre-declared expressions (`plus` is 
problematic)
+            a = df.id + "string"
+            b = df.id + 4
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id / 10, a, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`minus` is 
problematic)
+            a = df.id - "string"
+            b = df.id - 5
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`multiply` 
is problematic)
+            a = df.id * "string"
+            b = df.id * 10
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id + 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext without pysparkLoggingInfo
+            with self.assertRaises(AnalysisException) as pe:
+                df.select("non-existing-column")
+            self.check_error(
+                exception=pe.exception,
+                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                message_parameters={"objectName": "`non-existing-column`", 
"proposal": "`id`"},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="",
+            )
+
+            # SQLQueryContext
+            with self.assertRaises(ArithmeticException) as pe:
+                self.spark.sql("select 10/0").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.SQL,
+            )
+
+            # No QueryContext
+            with self.assertRaises(AnalysisException) as pe:
+                self.spark.sql("select * from non-existing-table")
+            self.check_error(
+                exception=pe.exception,
+                error_class="INVALID_IDENTIFIER",
+                message_parameters={"ident": "non-existing-table"},
+                query_context_type=None,
+            )
+
 
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index de40685dedc0..fe25136864ee 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -54,6 +54,8 @@ except ImportError:
 
 from pyspark import SparkConf
 from pyspark.errors import PySparkAssertionError, PySparkException
+from pyspark.errors.exceptions.captured import CapturedException
+from pyspark.errors.exceptions.base import QueryContextType
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import Row
@@ -284,7 +286,14 @@ class PySparkErrorTestUtils:
         exception: PySparkException,
         error_class: str,
         message_parameters: Optional[Dict[str, str]] = None,
+        query_context_type: Optional[QueryContextType] = None,
+        pyspark_fragment: Optional[str] = None,
     ):
+        query_context = exception.getQueryContext()
+        assert bool(query_context) == (query_context_type is not None), (
+            "`query_context_type` is required when QueryContext exists. "
+            f"QueryContext: {query_context}."
+        )
         # Test if given error is an instance of PySparkException.
         self.assertIsInstance(
             exception,
@@ -306,6 +315,27 @@ class PySparkErrorTestUtils:
             expected, actual, f"Expected message parameters was '{expected}', 
got '{actual}'"
         )
 
+        # Test query context
+        if query_context:
+            expected = query_context_type
+            actual_contexts = exception.getQueryContext()
+            for actual_context in actual_contexts:
+                actual = actual_context.contextType()
+                self.assertEqual(
+                    expected, actual, f"Expected QueryContext was 
'{expected}', got '{actual}'"
+                )
+                if actual == QueryContextType.DataFrame:
+                    assert (
+                        pyspark_fragment is not None
+                    ), "`pyspark_fragment` is required when QueryContextType 
is DataFrame."
+                    expected = pyspark_fragment
+                    actual = actual_context.pysparkFragment()
+                    self.assertEqual(
+                        expected,
+                        actual,
+                        f"Expected PySpark fragment was '{expected}', got 
'{actual}'",
+                    )
+
 
 def assertSchemaEqual(
     actual: StructType,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 6cfa7ed195a7..0a84ecd8203f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -250,7 +250,7 @@ class ParseException private(
     val builder = new StringBuilder
     builder ++= "\n" ++= message
     start match {
-      case Origin(Some(l), Some(p), _, _, _, _, _, _) =>
+      case Origin(Some(l), Some(p), _, _, _, _, _, _, _) =>
         builder ++= s" (line $l, pos $p)\n"
         command.foreach { cmd =>
           val (above, below) = cmd.split("\n").splitAt(l)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index c716002ef35c..1c2456f00bcd 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -134,7 +134,9 @@ case class SQLQueryContext(
   override def callSite: String = throw SparkUnsupportedOperationException()
 }
 
-case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends 
QueryContext {
+case class DataFrameQueryContext(
+    stackTrace: Seq[StackTraceElement],
+    pysparkErrorContext: Option[(String, String)]) extends QueryContext {
   override val contextType = QueryContextType.DataFrame
 
   override def objectType: String = throw SparkUnsupportedOperationException()
@@ -155,16 +157,26 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends Que
 
   override val callSite: String = stackTrace.tail.mkString("\n")
 
+  val pysparkFragment: String = pysparkErrorContext.map(_._1).getOrElse("")
+  val pysparkCallSite: String = pysparkErrorContext.map(_._2).getOrElse("")
+
+  val (displayedFragment, displayedCallsite) = if 
(pysparkErrorContext.nonEmpty) {
+    (pysparkFragment, pysparkCallSite)
+  } else {
+    (fragment, callSite)
+  }
+
   override lazy val summary: String = {
     val builder = new StringBuilder
     builder ++= "== DataFrame ==\n"
     builder ++= "\""
 
-    builder ++= fragment
+    builder ++= displayedFragment
     builder ++= "\""
     builder ++= " was called from\n"
-    builder ++= callSite
+    builder ++= displayedCallsite
     builder += '\n'
+
     builder.result()
   }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index d8469d3056d5..9d3968b02535 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -32,10 +32,11 @@ case class Origin(
     sqlText: Option[String] = None,
     objectType: Option[String] = None,
     objectName: Option[String] = None,
-    stackTrace: Option[Array[StackTraceElement]] = None) {
+    stackTrace: Option[Array[StackTraceElement]] = None,
+    pysparkErrorContext: Option[(String, String)] = None) {
 
   lazy val context: QueryContext = if (stackTrace.isDefined) {
-    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq)
+    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, 
pysparkErrorContext)
   } else {
     SQLQueryContext(
       line, startPosition, startIndex, stopIndex, sqlText, objectType, 
objectName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index fdd315a44f1e..22c09c51c237 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -171,6 +171,29 @@ class Column(val expr: Expression) extends Logging {
     Column.fn(name, this, lit(other))
   }
 
+  /**
+   * A version of the `fn` method specifically designed for binary operations 
in PySpark
+   * that require logging information.
+   * This method is used when the operation involves another Column.
+   *
+   * @param name                The name of the operation to be performed.
+   * @param other               The value to be used in the operation, which 
will be converted to a
+   *                            Column if not already one.
+   * @param pysparkFragment     A string representing the 'fragment' of the 
PySpark error context,
+   *                            typically indicates the name of PySpark 
function.
+   * @param pysparkCallSite     A string representing the 'callSite' of the 
PySpark error context,
+   *                            providing the exact location within the 
PySpark code where the
+   *                            operation originated.
+   * @return A Column resulting from the operation.
+   */
+  private def fn(
+      name: String, other: Any, pysparkFragment: String, pysparkCallSite: 
String): Column = {
+    val tupleInfo = (pysparkFragment, pysparkCallSite)
+    withOrigin(Some(tupleInfo)) {
+      Column.fn(name, this, lit(other))
+    }
+  }
+
   override def toString: String = toPrettySQL(expr)
 
   override def equals(that: Any): Boolean = that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 9831ce62801a..1444eea09b27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -78,6 +78,31 @@ package object sql {
    */
   private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = 
"org.apache.spark.legacyINT96"
 
+  /**
+   * Captures the current Java stack trace up to a specified depth defined by 
the
+   * `spark.sql.stackTracesInDataFrameContext` configuration. This method 
helps in identifying
+   * the call sites in Spark code by filtering out the stack frames until it 
reaches the
+   * user code calling into Spark. This method is intended to be used for 
enhancing debuggability
+   * by providing detailed context about where in the Spark source code a 
particular operation
+   * was called from.
+   *
+   * This functionality is crucial for both debugging purposes and for 
providing more insightful
+   * logging and error messages. By capturing the stack trace up to a certain 
depth, it enables
+   * a more precise pinpointing of the execution flow, especially useful when 
troubleshooting
+   * complex interactions within Spark.
+   *
+   * @return An array of `StackTraceElement` representing the filtered stack 
trace.
+   */
+  private def captureStackTrace(): Array[StackTraceElement] = {
+    val st = Thread.currentThread().getStackTrace
+    var i = 0
+    // Find the beginning of Spark code traces
+    while (i < st.length && !sparkCode(st(i))) i += 1
+    // Stop at the end of the first Spark code traces
+    while (i < st.length && sparkCode(st(i))) i += 1
+    st.slice(from = i - 1, until = i + 
SQLConf.get.stackTracesInDataFrameContext)
+  }
+
   /**
    * This helper function captures the Spark API and its call site in the user 
code from the current
    * stacktrace.
@@ -98,15 +123,45 @@ package object sql {
     if (CurrentOrigin.get.stackTrace.isDefined) {
       f
     } else {
-      val st = Thread.currentThread().getStackTrace
-      var i = 0
-      // Find the beginning of Spark code traces
-      while (i < st.length && !sparkCode(st(i))) i += 1
-      // Stop at the end of the first Spark code traces
-      while (i < st.length && sparkCode(st(i))) i += 1
-      val origin = Origin(stackTrace = Some(st.slice(
-        from = i - 1,
-        until = i + SQLConf.get.stackTracesInDataFrameContext)))
+      val origin = Origin(stackTrace = Some(captureStackTrace()))
+      CurrentOrigin.withOrigin(origin)(f)
+    }
+  }
+
+  /**
+   * This overloaded helper function captures the call site information 
specifically for PySpark,
+   * using provided PySpark logging information instead of capturing the 
current Java stack trace.
+   *
+   * This method is designed to enhance the debuggability of PySpark by 
including PySpark-specific
+   * logging information (e.g., method names and call sites within PySpark 
scripts) in debug logs,
+   * without the overhead of capturing and processing Java stack traces that 
are less relevant
+   * to PySpark developers.
+   *
+   * The `pysparkErrorContext` parameter allows for passing PySpark call site 
information, which
+   * is then included in the Origin context. This facilitates more precise and 
useful logging for
+   * troubleshooting PySpark applications.
+   *
+   * This method should be used in places where PySpark API calls are made, 
and PySpark logging
+   * information is available and beneficial for debugging purposes.
+   *
+   * @param pysparkErrorContext Optional PySpark logging information including 
the call site,
+   *                            represented as a (String, String).
+   *                            This may contain keys like "fragment" and 
"callSite" to provide
+   *                            detailed context about the PySpark call site.
+   * @param f                   The function that can utilize the modified 
Origin context with
+   *                            PySpark logging information.
+   * @return The result of executing `f` within the context of the provided 
PySpark logging
+   *         information.
+   */
+  private[sql] def withOrigin[T](
+      pysparkErrorContext: Option[(String, String)] = None)(f: => T): T = {
+    if (CurrentOrigin.get.stackTrace.isDefined) {
+      f
+    } else {
+      val origin = Origin(
+        stackTrace = Some(captureStackTrace()),
+        pysparkErrorContext = pysparkErrorContext
+      )
       CurrentOrigin.withOrigin(origin)(f)
     }
   }

