This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 216379df3543 [SPARK-47858][SPARK-47852][PYTHON][SQL] Refactoring the structure for DataFrame error context
216379df3543 is described below

commit 216379df35435961106c5a2aef35d5f60a6723bf
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Thu Apr 18 16:01:54 2024 +0900

    [SPARK-47858][SPARK-47852][PYTHON][SQL] Refactoring the structure for DataFrame error context
    
    ### What changes were proposed in this pull request?
    
    This PR proposes refactoring the current structure for the DataFrame error context.
    
    This change also covers the reverse binary operations, so it addresses SPARK-47852 as well.
    
    ### Why are the changes needed?
    
    To make the DataFrame error context easier to manage and extend in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's an internal code refactoring.
    
    ### How was this patch tested?
    
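    The existing DataFrame error-context test, moved from `DataFrameTests.test_dataframe_error_context` to `DataFrameQueryContextTestsMixin.test_dataframe_query_context` (in `test_dataframe_query_context.py`), should pass.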
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46063 from itholic/error_context_refactoring.
    
    Authored-by: Haejoon Lee <haejoon....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/utils.py                     |  85 +++-
 python/pyspark/sql/column.py                       |  39 +-
 .../sql/tests/connect/test_parity_dataframe.py     |   4 -
 ...e.py => test_parity_dataframe_query_context.py} |  18 +-
 python/pyspark/sql/tests/test_dataframe.py         | 482 --------------------
 .../sql/tests/test_dataframe_query_context.py      | 497 +++++++++++++++++++++
 .../apache/spark/sql/catalyst/trees/origin.scala   |  17 +
 .../main/scala/org/apache/spark/sql/Column.scala   |  23 -
 .../main/scala/org/apache/spark/sql/package.scala  |  76 +---
 9 files changed, 617 insertions(+), 624 deletions(-)

diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index e1f249506dd0..16fba7e272bc 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -16,11 +16,20 @@
 #
 
 import re
-from typing import Dict, Match
-
+import functools
+import inspect
+import os
+from typing import Any, Callable, Dict, Match, TypeVar, Type, TYPE_CHECKING
 from pyspark.errors.error_classes import ERROR_CLASSES_MAP
 
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from py4j.java_gateway import JavaClass
+
+T = TypeVar("T")
+
+
 class ErrorClassesReader:
     """
     A reader to load error information from error_classes.py.
@@ -119,3 +128,75 @@ class ErrorClassesReader:
             message_template = main_message_template + " " + 
sub_message_template
 
         return message_template
+
+
+def _capture_call_site(
+    spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str
+) -> None:
+    """
+    Capture the call site information including file name, line number, and 
function name.
+    This function updates the thread-local storage from JVM side 
(PySparkCurrentOrigin)
+    with the current call site information when a PySpark API function is 
called.
+
+    Parameters
+    ----------
+    spark_session : SparkSession
+        Current active Spark session.
+    pyspark_origin : py4j.JavaClass
+        PySparkCurrentOrigin from current active Spark session.
+    fragment : str
+        The name of the PySpark API function being captured.
+
+    Notes
+    -----
+    The call site information is used to enhance error messages with the exact 
location
+    in the user code that led to the error.
+    """
+    stack = list(reversed(inspect.stack()))
+    depth = int(
+        spark_session.conf.get("spark.sql.stackTracesInDataFrameContext")  # 
type: ignore[arg-type]
+    )
+    selected_frames = stack[:depth]
+    call_sites = [f"{frame.filename}:{frame.lineno}" for frame in 
selected_frames]
+    call_sites_str = "\n".join(call_sites)
+
+    pyspark_origin.set(fragment, call_sites_str)
+
+
+def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    A decorator to capture and provide the call site information to the server 
side
+    when PySpark API functions are invoked.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.getActiveSession()
+        if spark is not None:
+            assert spark._jvm is not None
+            pyspark_origin = 
spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+
+            # Update call site when the function is called
+            _capture_call_site(spark, pyspark_origin, func.__name__)
+
+            try:
+                return func(*args, **kwargs)
+            finally:
+                pyspark_origin.clear()
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def with_origin_to_class(cls: Type[T]) -> Type[T]:
+    """
+    Decorate all methods of a class with `_with_origin` to capture call site 
information.
+    """
+    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+        for name, method in cls.__dict__.items():
+            if callable(method) and name != "__init__":
+                setattr(cls, name, _with_origin(method))
+    return cls
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index fb266b03c2ff..2e79e30285ba 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,7 +18,6 @@
 import sys
 import json
 import warnings
-import inspect
 from typing import (
     cast,
     overload,
@@ -33,6 +32,7 @@ from typing import (
 )
 
 from pyspark.errors import PySparkAttributeError, PySparkTypeError, 
PySparkValueError
+from pyspark.errors.utils import with_origin_to_class
 from pyspark.sql.types import DataType
 from pyspark.sql.utils import get_active_spark_context
 
@@ -175,46 +175,13 @@ def _bin_op(
     ["Column", Union["Column", "LiteralType", "DecimalLiteral", 
"DateTimeLiteral"]], "Column"
 ]:
     """Create a method for given binary operator"""
-    binary_operator_map = {
-        "plus": "+",
-        "minus": "-",
-        "divide": "/",
-        "multiply": "*",
-        "mod": "%",
-        "equalTo": "=",
-        "lt": "<",
-        "leq": "<=",
-        "geq": ">=",
-        "gt": ">",
-        "eqNullSafe": "<=>",
-        "bitwiseOR": "|",
-        "bitwiseAND": "&",
-        "bitwiseXOR": "^",
-        # Just following JVM rule even if the names of source and target are 
the same.
-        "and": "and",
-        "or": "or",
-    }
 
     def _(
         self: "Column",
         other: Union["Column", "LiteralType", "DecimalLiteral", 
"DateTimeLiteral"],
     ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
-        if name in binary_operator_map:
-            from pyspark.sql import SparkSession
-
-            spark = SparkSession._getActiveSessionOrCreate()
-            stack = list(reversed(inspect.stack()))
-            depth = int(
-                spark.conf.get("spark.sql.stackTracesInDataFrameContext")  # 
type: ignore[arg-type]
-            )
-            selected_frames = stack[:depth]
-            call_sites = [f"{frame.filename}:{frame.lineno}" for frame in 
selected_frames]
-            call_site_str = "\n".join(call_sites)
-
-            njc = getattr(self._jc, "fn")(binary_operator_map[name], jc, name, 
call_site_str)
-        else:
-            njc = getattr(self._jc, name)(jc)
+        njc = getattr(self._jc, name)(jc)
         return Column(njc)
 
     _.__doc__ = doc
@@ -234,9 +201,11 @@ def _reverse_op(
         return Column(jc)
 
     _.__doc__ = doc
+    _.__name__ = name
     return _
 
 
+@with_origin_to_class
 class Column:
 
     """
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 6210d4ec72fe..343f485553a9 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -30,10 +30,6 @@ class DataFrameParityTests(DataFrameTestsMixin, 
ReusedConnectTestCase):
     def test_toDF_with_schema_string(self):
         super().test_toDF_with_schema_string()
 
-    @unittest.skip("Spark Connect does not support DataFrameQueryContext 
currently.")
-    def test_dataframe_error_context(self):
-        super().test_dataframe_error_context()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
similarity index 66%
copy from python/pyspark/sql/tests/connect/test_parity_dataframe.py
copy to python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
index 6210d4ec72fe..38bcd5643984 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
@@ -17,27 +17,19 @@
 
 import unittest
 
-from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
+from pyspark.sql.tests.test_dataframe_query_context import 
DataFrameQueryContextTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
-    def test_help_command(self):
-        df = self.spark.createDataFrame(data=[{"foo": "bar"}, {"foo": "baz"}])
-        super().check_help_command(df)
-
-    @unittest.skip("Spark Connect does not support RDD but the tests depend on 
them.")
-    def test_toDF_with_schema_string(self):
-        super().test_toDF_with_schema_string()
-
+class DataFrameParityTests(DataFrameQueryContextTestsMixin, 
ReusedConnectTestCase):
     @unittest.skip("Spark Connect does not support DataFrameQueryContext 
currently.")
-    def test_dataframe_error_context(self):
-        super().test_dataframe_error_context()
+    def test_dataframe_query_context(self):
+        super().test_dataframe_query_context()
 
 
 if __name__ == "__main__":
     import unittest
-    from pyspark.sql.tests.connect.test_parity_dataframe import *  # noqa: F401
+    from pyspark.sql.tests.connect.test_parity_dataframe_query_context import 
*  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index 3f6a8eece5b0..4267d8271f57 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -835,488 +835,6 @@ class DataFrameTestsMixin:
         self.assertEqual(df.schema, schema)
         self.assertEqual(df.collect(), data)
 
-    def test_dataframe_error_context(self):
-        # SPARK-47274: Add more useful contexts for PySpark DataFrame API 
errors.
-        with self.sql_conf({"spark.sql.ansi.enabled": True}):
-            df = self.spark.range(10)
-
-            # DataFrameQueryContext with pysparkLoggingInfo - divide
-            with self.assertRaises(ArithmeticException) as pe:
-                df.withColumn("div_zero", df.id / 0).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DIVIDE_BY_ZERO",
-                message_parameters={"config": '"spark.sql.ansi.enabled"'},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - plus
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("plus_invalid_type", df.id + "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - minus
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("minus_invalid_type", df.id - "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - multiply
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("multiply_invalid_type", df.id * 
"string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - mod
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("mod_invalid_type", df.id % "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="mod",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - equalTo
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("equalTo_invalid_type", df.id == 
"string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="equalTo",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - lt
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("lt_invalid_type", df.id < "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="lt",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - leq
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("leq_invalid_type", df.id <= "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="leq",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - geq
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("geq_invalid_type", df.id >= "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="geq",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - gt
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("gt_invalid_type", df.id > "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="gt",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("eqNullSafe_invalid_type", 
df.id.eqNullSafe("string")).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="eqNullSafe",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - and
-            with self.assertRaises(AnalysisException) as pe:
-                df.withColumn("and_invalid_type", df.id & "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
-                message_parameters={
-                    "inputType": '"BOOLEAN"',
-                    "actualDataType": '"BIGINT"',
-                    "sqlExpr": '"(id AND string)"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="and",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - or
-            with self.assertRaises(AnalysisException) as pe:
-                df.withColumn("or_invalid_type", df.id | "string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
-                message_parameters={
-                    "inputType": '"BOOLEAN"',
-                    "actualDataType": '"BIGINT"',
-                    "sqlExpr": '"(id OR string)"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="or",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("bitwiseOR_invalid_type", 
df.id.bitwiseOR("string")).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseOR",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("bitwiseAND_invalid_type", 
df.id.bitwiseAND("string")).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseAND",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("bitwiseXOR_invalid_type", 
df.id.bitwiseXOR("string")).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseXOR",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`divide` is problematic)
-            with self.assertRaises(ArithmeticException) as pe:
-                df.withColumn("multiply_ten", df.id * 10).withColumn(
-                    "divide_zero", df.id / 0
-                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", 
df.id - 10).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DIVIDE_BY_ZERO",
-                message_parameters={"config": '"spark.sql.ansi.enabled"'},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` 
is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("multiply_ten", df.id * 10).withColumn(
-                    "divide_ten", df.id / 10
-                ).withColumn("plus_string", df.id + "string").withColumn(
-                    "minus_ten", df.id - 10
-                ).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` 
is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("multiply_ten", df.id * 10).withColumn(
-                    "divide_ten", df.id / 10
-                ).withColumn("plus_ten", df.id + 10).withColumn(
-                    "minus_string", df.id - "string"
-                ).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
-            )
-
-            # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`multiply` is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.withColumn("multiply_string", df.id * "string").withColumn(
-                    "divide_ten", df.id / 10
-                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", 
df.id - 10).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # Multiple expressions in df.select (`divide` is problematic)
-            with self.assertRaises(ArithmeticException) as pe:
-                df.select(df.id - 10, df.id + 4, df.id / 0, df.id * 
5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DIVIDE_BY_ZERO",
-                message_parameters={"config": '"spark.sql.ansi.enabled"'},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
-            )
-
-            # Multiple expressions in df.select (`plus` is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(df.id - 10, df.id + "string", df.id / 10, df.id * 
5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
-            )
-
-            # Multiple expressions in df.select (`minus` is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(df.id - "string", df.id + 4, df.id / 10, df.id * 
5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
-            )
-
-            # Multiple expressions in df.select (`multiply` is problematic)
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(df.id - 10, df.id + 4, df.id / 10, df.id * 
"string").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # Multiple expressions with pre-declared expressions (`divide` is 
problematic)
-            a = df.id / 10
-            b = df.id / 0
-            with self.assertRaises(ArithmeticException) as pe:
-                df.select(a, df.id + 4, b, df.id * 5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DIVIDE_BY_ZERO",
-                message_parameters={"config": '"spark.sql.ansi.enabled"'},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
-            )
-
-            # Multiple expressions with pre-declared expressions (`plus` is 
problematic)
-            a = df.id + "string"
-            b = df.id + 4
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(df.id / 10, a, b, df.id * 5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
-            )
-
-            # Multiple expressions with pre-declared expressions (`minus` is 
problematic)
-            a = df.id - "string"
-            b = df.id - 5
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(a, df.id / 10, b, df.id * 5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
-            )
-
-            # Multiple expressions with pre-declared expressions (`multiply` 
is problematic)
-            a = df.id * "string"
-            b = df.id * 10
-            with self.assertRaises(NumberFormatException) as pe:
-                df.select(a, df.id / 10, b, df.id + 5).collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="CAST_INVALID_INPUT",
-                message_parameters={
-                    "expression": "'string'",
-                    "sourceType": '"STRING"',
-                    "targetType": '"BIGINT"',
-                    "ansiConfig": '"spark.sql.ansi.enabled"',
-                },
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # DataFrameQueryContext without pysparkLoggingInfo
-            with self.assertRaises(AnalysisException) as pe:
-                df.select("non-existing-column")
-            self.check_error(
-                exception=pe.exception,
-                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                message_parameters={"objectName": "`non-existing-column`", 
"proposal": "`id`"},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="",
-            )
-
-            # SQLQueryContext
-            with self.assertRaises(ArithmeticException) as pe:
-                self.spark.sql("select 10/0").collect()
-            self.check_error(
-                exception=pe.exception,
-                error_class="DIVIDE_BY_ZERO",
-                message_parameters={"config": '"spark.sql.ansi.enabled"'},
-                query_context_type=QueryContextType.SQL,
-            )
-
-            # No QueryContext
-            with self.assertRaises(AnalysisException) as pe:
-                self.spark.sql("select * from non-existing-table")
-            self.check_error(
-                exception=pe.exception,
-                error_class="INVALID_IDENTIFIER",
-                message_parameters={"ident": "non-existing-table"},
-                query_context_type=None,
-            )
-
 
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py 
b/python/pyspark/sql/tests/test_dataframe_query_context.py
new file mode 100644
index 000000000000..42fb0b0e452f
--- /dev/null
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -0,0 +1,497 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pyspark.errors import (
+    AnalysisException,
+    ArithmeticException,
+    QueryContextType,
+    NumberFormatException,
+)
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+)
+
+
+class DataFrameQueryContextTestsMixin:
+    def test_dataframe_query_context(self):
+        # SPARK-47274: Add more useful contexts for PySpark DataFrame API errors.
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
+            df = self.spark.range(10)
+
+            # DataFrameQueryContext with pysparkLoggingInfo - divide
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("div_zero", df.id / 0).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - plus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("plus_invalid_type", df.id + "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - minus
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("minus_invalid_type", df.id - "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - multiply
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_invalid_type", df.id * "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - mod
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("mod_invalid_type", df.id % "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="mod",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - equalTo
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("equalTo_invalid_type", df.id == "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="__eq__",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - lt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("lt_invalid_type", df.id < "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="lt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - leq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("leq_invalid_type", df.id <= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="leq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - geq
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("geq_invalid_type", df.id >= "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="geq",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - gt
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("gt_invalid_type", df.id > "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="gt",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("eqNullSafe_invalid_type", df.id.eqNullSafe("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="eqNullSafe",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseOR_invalid_type", df.id.bitwiseOR("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseOR",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseAND_invalid_type", df.id.bitwiseAND("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseAND",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("bitwiseXOR_invalid_type", df.id.bitwiseXOR("string")).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="bitwiseXOR",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`divide` is problematic)
+            with self.assertRaises(ArithmeticException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_zero", df.id / 0
+                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_string", df.id + "string").withColumn(
+                    "minus_ten", df.id - 10
+                ).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_ten", df.id * 10).withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_ten", df.id + 10).withColumn(
+                    "minus_string", df.id - "string"
+                ).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # DataFrameQueryContext with pysparkLoggingInfo - chained (`multiply` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.withColumn("multiply_string", df.id * "string").withColumn(
+                    "divide_ten", df.id / 10
+                ).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # Multiple expressions in df.select (`divide` is problematic)
+            with self.assertRaises(ArithmeticException) as pe:
+                df.select(df.id - 10, df.id + 4, df.id / 0, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # Multiple expressions in df.select (`plus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + "string", df.id / 10, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions in df.select (`minus` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - "string", df.id + 4, df.id / 10, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions in df.select (`multiply` is problematic)
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id - 10, df.id + 4, df.id / 10, df.id * "string").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # Multiple expressions with pre-declared expressions (`divide` is problematic)
+            a = df.id / 10
+            b = df.id / 0
+            with self.assertRaises(ArithmeticException) as pe:
+                df.select(a, df.id + 4, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="divide",
+            )
+
+            # Multiple expressions with pre-declared expressions (`plus` is problematic)
+            a = df.id + "string"
+            b = df.id + 4
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(df.id / 10, a, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="plus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`minus` is problematic)
+            a = df.id - "string"
+            b = df.id - 5
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id * 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="minus",
+            )
+
+            # Multiple expressions with pre-declared expressions (`multiply` is problematic)
+            a = df.id * "string"
+            b = df.id * 10
+            with self.assertRaises(NumberFormatException) as pe:
+                df.select(a, df.id / 10, b, df.id + 5).collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="CAST_INVALID_INPUT",
+                message_parameters={
+                    "expression": "'string'",
+                    "sourceType": '"STRING"',
+                    "targetType": '"BIGINT"',
+                    "ansiConfig": '"spark.sql.ansi.enabled"',
+                },
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="multiply",
+            )
+
+            # DataFrameQueryContext without pysparkLoggingInfo
+            with self.assertRaises(AnalysisException) as pe:
+                df.select("non-existing-column")
+            self.check_error(
+                exception=pe.exception,
+                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
+                query_context_type=QueryContextType.DataFrame,
+                pyspark_fragment="",
+            )
+
+            # SQLQueryContext
+            with self.assertRaises(ArithmeticException) as pe:
+                self.spark.sql("select 10/0").collect()
+            self.check_error(
+                exception=pe.exception,
+                error_class="DIVIDE_BY_ZERO",
+                message_parameters={"config": '"spark.sql.ansi.enabled"'},
+                query_context_type=QueryContextType.SQL,
+            )
+
+            # No QueryContext
+            with self.assertRaises(AnalysisException) as pe:
+                self.spark.sql("select * from non-existing-table")
+            self.check_error(
+                exception=pe.exception,
+                error_class="INVALID_IDENTIFIER",
+                message_parameters={"ident": "non-existing-table"},
+                query_context_type=None,
+            )
+
+
+class DataFrameQueryContextTests(DataFrameQueryContextTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_dataframe_query_context import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index 9d3968b02535..4ecbfd631e7e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -85,3 +85,20 @@ object CurrentOrigin {
     ret
   }
 }
+
+/**
+ * Provides detailed error context information on PySpark.
+ */
+object PySparkCurrentOrigin {
+  private val pysparkErrorContext = new ThreadLocal[Option[(String, String)]]() {
+    override def initialValue(): Option[(String, String)] = None
+  }
+
+  def set(fragment: String, callSite: String): Unit = {
+    pysparkErrorContext.set(Some((fragment, callSite)))
+  }
+
+  def get(): Option[(String, String)] = pysparkErrorContext.get()
+
+  def clear(): Unit = pysparkErrorContext.remove()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d9f32682ab69..1cf315d45f65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -172,29 +172,6 @@ class Column(val expr: Expression) extends Logging {
     Column.fn(name, this, lit(other))
   }
 
-  /**
-   * A version of the `fn` method specifically designed for binary operations in PySpark
-   * that require logging information.
-   * This method is used when the operation involves another Column.
-   *
-   * @param name                The name of the operation to be performed.
-   * @param other               The value to be used in the operation, which will be converted to a
-   *                            Column if not already one.
-   * @param pysparkFragment     A string representing the 'fragment' of the PySpark error context,
-   *                            typically indicates the name of PySpark function.
-   * @param pysparkCallSite     A string representing the 'callSite' of the PySpark error context,
-   *                            providing the exact location within the PySpark code where the
-   *                            operation originated.
-   * @return A Column resulting from the operation.
-   */
-  private def fn(
-      name: String, other: Any, pysparkFragment: String, pysparkCallSite: String): Column = {
-    val tupleInfo = (pysparkFragment, pysparkCallSite)
-    withOrigin(Some(tupleInfo)) {
-      Column.fn(name, this, lit(other))
-    }
-  }
-
   override def toString: String = toPrettySQL(expr)
 
   override def equals(that: Any): Boolean = that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 1444eea09b27..96b5e2193f27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import java.util.regex.Pattern
 
 import org.apache.spark.annotation.{DeveloperApi, Unstable}
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, PySparkCurrentOrigin}
 import org.apache.spark.sql.execution.SparkStrategy
 import org.apache.spark.sql.internal.SQLConf
 
@@ -78,31 +78,6 @@ package object sql {
    */
  private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96"
 
-  /**
-   * Captures the current Java stack trace up to a specified depth defined by the
-   * `spark.sql.stackTracesInDataFrameContext` configuration. This method helps in identifying
-   * the call sites in Spark code by filtering out the stack frames until it reaches the
-   * user code calling into Spark. This method is intended to be used for enhancing debuggability
-   * by providing detailed context about where in the Spark source code a particular operation
-   * was called from.
-   *
-   * This functionality is crucial for both debugging purposes and for providing more insightful
-   * logging and error messages. By capturing the stack trace up to a certain depth, it enables
-   * a more precise pinpointing of the execution flow, especially useful when troubleshooting
-   * complex interactions within Spark.
-   *
-   * @return An array of `StackTraceElement` representing the filtered stack trace.
-   */
-  private def captureStackTrace(): Array[StackTraceElement] = {
-    val st = Thread.currentThread().getStackTrace
-    var i = 0
-    // Find the beginning of Spark code traces
-    while (i < st.length && !sparkCode(st(i))) i += 1
-    // Stop at the end of the first Spark code traces
-    while (i < st.length && sparkCode(st(i))) i += 1
-    st.slice(from = i - 1, until = i + SQLConf.get.stackTracesInDataFrameContext)
-  }
-
   /**
    * This helper function captures the Spark API and its call site in the user code from the current
    * stacktrace.
@@ -123,45 +98,16 @@ package object sql {
     if (CurrentOrigin.get.stackTrace.isDefined) {
       f
     } else {
-      val origin = Origin(stackTrace = Some(captureStackTrace()))
-      CurrentOrigin.withOrigin(origin)(f)
-    }
-  }
-
-  /**
-   * This overloaded helper function captures the call site information specifically for PySpark,
-   * using provided PySpark logging information instead of capturing the current Java stack trace.
-   *
-   * This method is designed to enhance the debuggability of PySpark by including PySpark-specific
-   * logging information (e.g., method names and call sites within PySpark scripts) in debug logs,
-   * without the overhead of capturing and processing Java stack traces that are less relevant
-   * to PySpark developers.
-   *
-   * The `pysparkErrorContext` parameter allows for passing PySpark call site information, which
-   * is then included in the Origin context. This facilitates more precise and useful logging for
-   * troubleshooting PySpark applications.
-   *
-   * This method should be used in places where PySpark API calls are made, and PySpark logging
-   * information is available and beneficial for debugging purposes.
-   *
-   * @param pysparkErrorContext Optional PySpark logging information including the call site,
-   *                            represented as a (String, String).
-   *                            This may contain keys like "fragment" and "callSite" to provide
-   *                            detailed context about the PySpark call site.
-   * @param f                   The function that can utilize the modified Origin context with
-   *                            PySpark logging information.
-   * @return The result of executing `f` within the context of the provided PySpark logging
-   *         information.
-   */
-  private[sql] def withOrigin[T](
-      pysparkErrorContext: Option[(String, String)] = None)(f: => T): T = {
-    if (CurrentOrigin.get.stackTrace.isDefined) {
-      f
-    } else {
-      val origin = Origin(
-        stackTrace = Some(captureStackTrace()),
-        pysparkErrorContext = pysparkErrorContext
-      )
+      val st = Thread.currentThread().getStackTrace
+      var i = 0
+      // Find the beginning of Spark code traces
+      while (i < st.length && !sparkCode(st(i))) i += 1
+      // Stop at the end of the first Spark code traces
+      while (i < st.length && sparkCode(st(i))) i += 1
+      val origin = Origin(stackTrace = Some(st.slice(
+        from = i - 1,
+        until = i + SQLConf.get.stackTracesInDataFrameContext)),
+        pysparkErrorContext = PySparkCurrentOrigin.get())
       CurrentOrigin.withOrigin(origin)(f)
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to