This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 26a74598830a [SPARK-52144][PYTHON][TESTS] Refactor memory profiling 
tests with shared assertion utility
26a74598830a is described below

commit 26a74598830aaf396d8dae8cbf20806cfa3ec8ab
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Fri May 16 10:12:23 2025 +0800

    [SPARK-52144][PYTHON][TESTS] Refactor memory profiling tests with shared 
assertion utility
    
    ### What changes were proposed in this pull request?
    Refactor memory profiling tests with shared assertion utility
    
    ### Why are the changes needed?
    To reduce code duplication, improve readability, and ensure consistent 
validation of UDF profiling output across test cases.
    
    Part of https://issues.apache.org/jira/browse/SPARK-52093
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Test changes only.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #50896 from xinrong-meng/test_memory_profile.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/tests/test_memory_profiler.py | 113 +++++++--------------------
 1 file changed, 27 insertions(+), 86 deletions(-)

diff --git a/python/pyspark/tests/test_memory_profiler.py 
b/python/pyspark/tests/test_memory_profiler.py
index 0e921b48afc4..ef427271c97a 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -219,6 +219,21 @@ class MemoryProfiler2TestsMixin:
         finally:
             sys.stdout = old_stdout
 
+    def assert_udf_memory_profile_present(self, udf_id, dump_dir=None):
+        """
+        Assert that a memory profile for the given UDF ID exists, has expected 
content,
+        and is associated with the source file of `_do_computation`.
+        """
+        with self.trap_stdout() as io:
+            self.spark.profile.show(udf_id, type="memory")
+
+        self.assertIn(f"Profile of UDF<id={udf_id}>", io.getvalue())
+        self.assertRegex(
+            io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
+        )
+        if dump_dir:
+            self.assertTrue(f"udf_{udf_id}_memory.txt" in os.listdir(dump_dir))
+
     @property
     def profile_results(self):
         return self.spark._profiler_collector._memory_profile_results
@@ -246,15 +261,7 @@ class MemoryProfiler2TestsMixin:
 
             for id in self.profile_results:
                 self.assertIn(f"Profile of UDF<id={id}>", io_all.getvalue())
-
-                with self.trap_stdout() as io:
-                    self.spark.profile.show(id, type="memory")
-
-                self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-                self.assertRegex(
-                    io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-                )
-                self.assertTrue(f"udf_{id}_memory.txt" in os.listdir(d))
+                self.assert_udf_memory_profile_present(udf_id=id, dump_dir=d)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -267,13 +274,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(3, len(self.profile_results), 
str(list(self.profile_results)))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     def test_memory_profiler_udf_multiple_actions(self):
         def action(df):
@@ -286,13 +287,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(3, len(self.profile_results), 
str(list(self.profile_results)))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     def test_memory_profiler_udf_registered(self):
         @udf("long")
@@ -307,13 +302,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -337,13 +326,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(3, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -370,13 +353,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -417,13 +394,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -446,13 +417,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -474,13 +439,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -509,13 +468,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -540,13 +493,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -569,13 +516,7 @@ class MemoryProfiler2TestsMixin:
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            with self.trap_stdout() as io:
-                self.spark.profile.show(id, type="memory")
-
-            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
-            self.assertRegex(
-                io.getvalue(), 
f"Filename.*{os.path.basename(inspect.getfile(_do_computation))}"
-            )
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     def test_memory_profiler_clear(self):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to