This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bd9dd3887ea [SPARK-43967][SQL][PYTHON] Add memory limits for Python UDTF analyzer
bd9dd3887ea is described below

commit bd9dd3887eaf8e80a7084774fa3e893f2b91f659
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Aug 4 14:58:04 2023 +0800

    [SPARK-43967][SQL][PYTHON] Add memory limits for Python UDTF analyzer
    
    ### What changes were proposed in this pull request?
    
    Adds memory limits for Python UDTF analyzer.
    
    - `spark.sql.analyzer.pythonUDTF.analyzeInPython.memory` (`None` by default)
    
    > The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be limited to this amount. If not set, Spark will not limit Python's memory use and it is up to the application to avoid exceeding the overhead memory space shared with other non-JVM processes.
    Note: Windows does not support resource limiting and actual resource is not limited on MacOS.
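
    For context, the limit is enforced with the standard Python `resource` module (`RLIMIT_AS`), the same mechanism the executor workers already use. Below is a minimal standalone sketch of how such a cap behaves on Linux; the 512 MiB cap and the 1 GiB allocation are illustrative values, not Spark defaults:

    ```python
    import resource

    # Illustrative cap of 512 MiB on the process address space (RLIMIT_AS).
    limit_bytes = 512 * 1024 * 1024
    resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))

    try:
        buf = bytearray(1024 * 1024 * 1024)  # try to allocate 1 GiB
    except MemoryError:
        # Exceeding RLIMIT_AS surfaces as a MemoryError inside the Python process.
        print("allocation rejected by the address-space limit")
    ```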
    
    ### Why are the changes needed?
    
    The Python UDTF analyzer should be able to run with a configurable memory limit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Users will be able to set a memory limit for the Python UDTF analyzer.
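
    For example (a sketch, not part of this patch), the limit could be set when building a session; the value is interpreted in MiB unless another unit is specified:

    ```python
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        # Illustrative 512 MiB cap for the Python UDTF analyzer worker.
        .config("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory", "512")
        .getOrCreate()
    )
    ```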
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #42328 from ueshin/issues/SPARK-44648/memory_limits.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/worker/analyze_udtf.py          |  5 +++
 python/pyspark/worker.py                           | 42 ++-----------------
 python/pyspark/worker_util.py                      | 47 ++++++++++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 14 +++++++
 .../python/UserDefinedPythonFunction.scala         |  4 ++
 5 files changed, 73 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 44dcd8c892c..9ffa03541e6 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -40,6 +40,7 @@ from pyspark.worker_util import (
     pickleSer,
     send_accumulator_updates,
     setup_broadcasts,
+    setup_memory_limits,
     setup_spark_files,
     utf8_deserializer,
 )
@@ -96,6 +97,10 @@ def main(infile: IO, outfile: IO) -> None:
     """
     try:
         check_python_version(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1"))
+        setup_memory_limits(memory_limit_mb)
+
         setup_spark_files(infile)
         setup_broadcasts(infile)
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3acfa58b6fb..b32e20e3b04 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -21,18 +21,11 @@ Worker that receives input from Piped RDD.
 import os
 import sys
 import time
-from inspect import currentframe, getframeinfo, getfullargspec
+from inspect import getfullargspec
 import json
 from typing import Iterable, Iterator
 
-# 'resource' is a Unix specific module.
-has_resource_module = True
-try:
-    import resource
-except ImportError:
-    has_resource_module = False
 import traceback
-import warnings
 import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
@@ -70,6 +63,7 @@ from pyspark.worker_util import (
     pickleSer,
     send_accumulator_updates,
     setup_broadcasts,
+    setup_memory_limits,
     setup_spark_files,
     utf8_deserializer,
 )
@@ -998,38 +992,8 @@ def main(infile, outfile):
         boundPort = read_int(infile)
         secret = UTF8Deserializer().loads(infile)
 
-        # set up memory limits
         memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
-        if memory_limit_mb > 0 and has_resource_module:
-            total_memory = resource.RLIMIT_AS
-            try:
-                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
-                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
-                print(msg, file=sys.stderr)
-
-                # convert to bytes
-                new_limit = memory_limit_mb * 1024 * 1024
-
-                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
-                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
-                    print(msg, file=sys.stderr)
-                    resource.setrlimit(total_memory, (new_limit, new_limit))
-
-            except (resource.error, OSError, ValueError) as e:
-                # not all systems support resource limits, so warn instead of failing
-                lineno = (
-                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
-                )
-                if "__file__" in globals():
-                    print(
-                        warnings.formatwarning(
-                            "Failed to set memory limit: {0}".format(e),
-                            ResourceWarning,
-                            __file__,
-                            lineno,
-                        ),
-                        file=sys.stderr,
-                    )
+        setup_memory_limits(memory_limit_mb)
 
         # initialize global state
         taskContext = None
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index eab0daf8f59..9f6d46c6211 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -19,9 +19,18 @@
 Util functions for workers.
 """
 import importlib
+from inspect import currentframe, getframeinfo
 import os
 import sys
 from typing import Any, IO
+import warnings
+
+# 'resource' is a Unix specific module.
+has_resource_module = True
+try:
+    import resource
+except ImportError:
+    has_resource_module = False
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -71,6 +80,44 @@ def check_python_version(infile: IO) -> None:
         )
 
 
+def setup_memory_limits(memory_limit_mb: int) -> None:
+    """
+    Sets up the memory limits.
+
+    If memory_limit_mb > 0 and `resource` module is available, sets the memory limit.
+    Windows does not support resource limiting and actual resource is not limited on MacOS.
+    """
+    if memory_limit_mb > 0 and has_resource_module:
+        total_memory = resource.RLIMIT_AS
+        try:
+            (soft_limit, hard_limit) = resource.getrlimit(total_memory)
+            msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
+            print(msg, file=sys.stderr)
+
+            # convert to bytes
+            new_limit = memory_limit_mb * 1024 * 1024
+
+            if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
+                msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
+                print(msg, file=sys.stderr)
+                resource.setrlimit(total_memory, (new_limit, new_limit))
+
+        except (resource.error, OSError, ValueError) as e:
+            # not all systems support resource limits, so warn instead of failing
+            curent = currentframe()
+            lineno = getframeinfo(curent).lineno + 1 if curent is not None else 0
+            if "__file__" in globals():
+                print(
+                    warnings.formatwarning(
+                        "Failed to set memory limit: {0}".format(e),
+                        ResourceWarning,
+                        __file__,
+                        lineno,
+                    ),
+                    file=sys.stderr,
+                )
+
+
 def setup_spark_files(infile: IO) -> None:
     """
     Set up Spark files, archives, and pyfiles.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index dfa2a0f251f..ad2d323140a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2944,6 +2944,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
+    buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")
+      .doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " +
+        "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " +
+        "limited to this amount. If not set, Spark will not limit Python's " +
+        "memory use and it is up to the application to avoid exceeding the overhead memory space " +
+        "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " +
+        "and actual resource is not limited on MacOS.")
+      .version("4.0.0")
+      .bytesConf(ByteUnit.MiB)
+      .createOptional
+
   val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME =
     buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName")
       .internal()
@@ -5012,6 +5024,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def pysparkWorkerPythonExecutable: Option[String] =
     getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
 
+  def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY)
+
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 05239d8d164..36cb2e17835 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -175,6 +175,7 @@ object UserDefinedPythonTableFunction {
     val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
     val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+    val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory
 
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
@@ -192,6 +193,9 @@ object UserDefinedPythonTableFunction {
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
+    workerMemoryMb.foreach { memoryMb =>
+      envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString)
+    }
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
 

