This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 4cba54d8c0e [SPARK-44663][PYTHON] Disable arrow optimization by default for Python UDTFs 4cba54d8c0e is described below commit 4cba54d8c0e113a2587082235518be738c3d4dda Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Fri Aug 4 16:44:01 2023 -0700 [SPARK-44663][PYTHON] Disable arrow optimization by default for Python UDTFs ### What changes were proposed in this pull request? This PR disables arrow optimization by default for Python UDTFs. ### Why are the changes needed? To make Python UDTFs consistent with Python UDFs (arrow optimization is by default disabled). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests Closes #42329 from allisonwang-db/spark-44663-disable-arrow. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> (cherry picked from commit 8b53fed7ef0edaaf948ec67413017e60444230fd) Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/connect/udtf.py | 10 +++++----- python/pyspark/sql/tests/test_udtf.py | 16 ++++++++++++++++ python/pyspark/sql/udtf.py | 10 +++++----- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 3747e37459e..07e2bad6ec7 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -70,11 +70,11 @@ def _create_py_udtf( else: from pyspark.sql.connect.session import _active_spark_session - arrow_enabled = ( - _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if _active_spark_session is not None - else True - ) + arrow_enabled = False + if _active_spark_session is not None: + value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + if isinstance(value, str) and value.lower() == "true": + arrow_enabled = True # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 4bab77038e0..4a65a9bd2e4 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1002,6 +1002,22 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): PythonEvalType.SQL_ARROW_TABLE_UDF, ) + def test_udtf_arrow_sql_conf(self): + class TestUDTF: + def eval(self): + yield 1, + + # We do not use `self.sql_conf` here to test the SQL SET command + # instead of using PySpark's `spark.conf.set`. + old_value = self.spark.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=False") + self.assertEqual(udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_TABLE_UDF) + self.spark.sql("SET spark.sql.execution.pythonUDTF.arrow.enabled=True") + self.assertEqual( + udtf(TestUDTF, returnType="x: int").evalType, PythonEvalType.SQL_ARROW_TABLE_UDF + ) + self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) + def test_udtf_eval_returning_non_tuple(self): class TestUDTF: def eval(self, a: int): diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index c2830d56db5..7cbf4732ba9 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -69,11 +69,11 @@ def _create_py_udtf( from pyspark.sql import SparkSession session = SparkSession._instantiatedSession - arrow_enabled = ( - session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" - if session is not None - else True - ) + arrow_enabled = False + if session is not None: + value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") + if isinstance(value, str) and value.lower() == "true": + arrow_enabled = True # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 444ecbd837f..083e88380c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2926,7 +2926,7 @@ object SQLConf { .doc("Enable Arrow optimization for Python UDTFs.") .version("3.5.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org