This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 8b53fed7ef0  [SPARK-44663][PYTHON] Disable arrow optimization by default for Python UDTFs
8b53fed7ef0 is described below

commit 8b53fed7ef0edaaf948ec67413017e60444230fd
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Aug 4 16:44:01 2023 -0700

    [SPARK-44663][PYTHON] Disable arrow optimization by default for Python UDTFs

    ### What changes were proposed in this pull request?

    This PR disables arrow optimization by default for Python UDTFs.

    ### Why are the changes needed?

    To make Python UDTFs consistent with Python UDFs (arrow optimization is by default disabled).

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New unit tests

    Closes #42329 from allisonwang-db/spark-44663-disable-arrow.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/connect/udtf.py                     | 10 +++++-----
 python/pyspark/sql/tests/test_udtf.py                  | 16 ++++++++++++++++
 python/pyspark/sql/udtf.py                             | 10 +++++-----
 .../scala/org/apache/spark/sql/internal/SQLConf.scala  |  2 +-
 4 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index 919994401c8..5a95075a655 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -70,11 +70,11 @@ def _create_py_udtf(
     else:
         from pyspark.sql.connect.session import _active_spark_session
 
-        arrow_enabled = (
-            _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true"
-            if _active_spark_session is not None
-            else True
-        )
+        arrow_enabled = False
+        if _active_spark_session is not None:
+            value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
+            if isinstance(value, str) and value.lower() == "true":
+                arrow_enabled = True
 
     # Create a regular Python UDTF and check for invalid handler class.
     regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 4d36b537995..9caf267e48d 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1723,6 +1723,22 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
                 PythonEvalType.SQL_ARROW_TABLE_UDF,
             )
 
+    def test_udtf_arrow_sql_conf(self):
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+
+        # We do not use `self.sql_conf` here to test the SQL SET command
+        # instead of using PySpark's `spark.conf.set`.
+        old_value = self.spark.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
+        self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=False")
+        self.assertEqual(udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_TABLE_UDF)
+        self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=True")
+        self.assertEqual(
+            udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_ARROW_TABLE_UDF
+        )
+        self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value)
+
     def test_udtf_eval_returning_non_tuple(self):
         class TestUDTF:
             def eval(self, a: int):
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index fea0f74c8f2..027a2646a46 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -106,11 +106,11 @@ def _create_py_udtf(
         from pyspark.sql import SparkSession
 
         session = SparkSession._instantiatedSession
-        arrow_enabled = (
-            session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true"
-            if session is not None
-            else True
-        )
+        arrow_enabled = False
+        if session is not None:
+            value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
+            if isinstance(value, str) and value.lower() == "true":
+                arrow_enabled = True
 
     # Create a regular Python UDTF and check for invalid handler class.
     regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ad2d323140a..bcf8ce2bc54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2942,7 +2942,7 @@ object SQLConf {
       .doc("Enable Arrow optimization for Python UDTFs.")
       .version("3.5.0")
       .booleanConf
-      .createWithDefault(true)
+      .createWithDefault(false)
 
   val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
     buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
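
For reference, a minimal usage sketch of the new default from the user side (not part of the patch): it assumes a local Spark 3.5 session, the SquareNumbers handler class is a hypothetical example, and the PythonEvalType import path mirrors the test module above. With this commit, a UDTF picks up the Arrow-optimized eval type only if spark.sql.execution.pythonUDTF.arrow.enabled is explicitly set to true.

    from pyspark.rdd import PythonEvalType
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import lit, udtf

    spark = SparkSession.builder.getOrCreate()

    # Hypothetical UDTF handler used only for illustration.
    class SquareNumbers:
        def eval(self, n: int):
            yield n, n * n

    # With the new default, a UDTF created without touching the conf uses the regular (non-Arrow) path.
    plain = udtf(SquareNumbers, returnType="n: int, n_squared: int")
    assert plain.evalType == PythonEvalType.SQL_TABLE_UDF

    # Opting back in via the conf switches newly created UDTFs to the Arrow-optimized path.
    spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
    arrow = udtf(SquareNumbers, returnType="n: int, n_squared: int")
    assert arrow.evalType == PythonEvalType.SQL_ARROW_TABLE_UDF

    # Executing the Arrow-optimized UDTF requires pyarrow to be installed.
    arrow(lit(3)).show()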