This is an automated email from the ASF dual-hosted git repository.
shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c33bc972f15 Fix race condition in UserPipelineTracker.clear() and various problems (#38537)
c33bc972f15 is described below
commit c33bc972f150f89ace5d2f3ef708cdfc1c231095
Author: Shunping Huang <[email protected]>
AuthorDate: Wed May 20 12:16:33 2026 -0400
Fix race condition in UserPipelineTracker.clear() and various problems (#38537)
* Fix race condition in UserPipelineTracker.clear()
* Address comments.
* Fix lints.
* Apply suggestions from code review
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Fix failed test RecordingTest.test_describe
* Fix failed tests test_instrument_example_pipeline_to_write_cache and
test_instrument_example_pipeline_to_read_cache.
* Formatting.
* Fix InteractiveBeamTest.test_recordings_clear and test_recordings_record.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../runners/interactive/interactive_environment.py | 10 +-
.../runners/interactive/user_pipeline_tracker.py | 109 ++++++++++++---------
.../interactive/user_pipeline_tracker_test.py | 48 +++++++++
3 files changed, 119 insertions(+), 48 deletions(-)
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index bfb1a7f1190..b243d20ff85 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -365,11 +365,17 @@ class InteractiveEnvironment(object):
if self.get_cache_manager(pipeline) is cache_manager:
# NOOP if setting to the same cache_manager.
return
+ # Check if the pipeline is already tracked as a user pipeline before
cleanup.
+ is_user_pipeline = self._tracked_user_pipelines.get_user_pipeline(
+ pipeline) is pipeline
if self.get_cache_manager(pipeline):
# Invoke cleanup routine when a new cache_manager is forcefully set and
# current cache_manager is not None.
self.cleanup(pipeline)
self._cache_managers[str(id(pipeline))] = cache_manager
+ if is_user_pipeline:
+ # Re-track the user pipeline because the self.cleanup() call above
evicts it.
+ self.add_user_pipeline(pipeline)
def get_cache_manager(self, pipeline, create_if_absent=False):
"""Gets the cache manager held by current Interactive Environment for the
@@ -468,8 +474,8 @@ class InteractiveEnvironment(object):
def describe_all_recordings(self):
"""Returns a description of the recording for all watched pipelnes."""
return {
- self.pipeline_id_to_pipeline(pid): rm.describe()
- for pid, rm in self._recording_managers.items()
+ rm.user_pipeline: rm.describe()
+ for rm in self._recording_managers.values()
}
def set_pipeline_result(self, pipeline, result):
diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
index 53ee54ac8a3..4c7871c02be 100644
--- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
+++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
@@ -25,6 +25,7 @@ that derived pipelines can link back to the parent user pipeline.
"""
import shutil
+import threading
from typing import Iterator
from typing import Optional
@@ -39,13 +40,16 @@ class UserPipelineTracker:
derived pipelines.
"""
def __init__(self):
+ self._lock = threading.RLock()
self._user_pipelines: dict[beam.Pipeline, list[beam.Pipeline]] = {}
- self._derived_pipelines: dict[beam.Pipeline] = {}
- self._pid_to_pipelines: dict[beam.Pipeline] = {}
+ self._derived_pipelines: dict[beam.Pipeline, beam.Pipeline] = {}
+ self._pid_to_pipelines: dict[str, beam.Pipeline] = {}
def __iter__(self) -> Iterator[beam.Pipeline]:
"""Iterates through all the user pipelines."""
- for p in self._user_pipelines:
+ with self._lock:
+ pipelines = list(self._user_pipelines.keys())
+ for p in pipelines:
yield p
def _key(self, pipeline: beam.Pipeline) -> str:
@@ -57,45 +61,57 @@ class UserPipelineTracker:
Removes the given pipeline and derived pipelines if a user pipeline.
Otherwise, removes the given derived pipeline.
"""
- user_pipeline = self.get_user_pipeline(pipeline)
- if user_pipeline:
- for d in self._user_pipelines[user_pipeline]:
- del self._derived_pipelines[d]
- del self._user_pipelines[user_pipeline]
- elif pipeline in self._derived_pipelines:
- del self._derived_pipelines[pipeline]
+ with self._lock:
+ if pipeline in self._user_pipelines:
+ for d in self._user_pipelines[pipeline]:
+ self._derived_pipelines.pop(d, None)
+ self._pid_to_pipelines.pop(self._key(d), None)
+ self._user_pipelines.pop(pipeline, None)
+ elif pipeline in self._derived_pipelines:
+ user_pipeline = self._derived_pipelines.pop(pipeline, None)
+ if user_pipeline in self._user_pipelines:
+ try:
+ self._user_pipelines[user_pipeline].remove(pipeline)
+ except ValueError:
+ pass
+ self._pid_to_pipelines.pop(self._key(pipeline), None)
def clear(self) -> None:
"""Clears the tracker of all user and derived pipelines."""
# Remove all local_tempdir of created pipelines.
- for p in self._pid_to_pipelines.values():
- shutil.rmtree(p.local_tempdir, ignore_errors=True)
+ with self._lock:
+ pipelines = list(self._pid_to_pipelines.values())
+ self._user_pipelines.clear()
+ self._derived_pipelines.clear()
+ self._pid_to_pipelines.clear()
- self._user_pipelines.clear()
- self._derived_pipelines.clear()
- self._pid_to_pipelines.clear()
+ for p in pipelines:
+ shutil.rmtree(p.local_tempdir, ignore_errors=True)
def get_pipeline(self, pid: str) -> Optional[beam.Pipeline]:
"""Returns the pipeline corresponding to the given pipeline id."""
- return self._pid_to_pipelines.get(pid, None)
+ with self._lock:
+ return self._pid_to_pipelines.get(pid, None)
def add_user_pipeline(self, p: beam.Pipeline) -> beam.Pipeline:
"""Adds a user pipeline with an empty set of derived pipelines."""
- self._memoize_pipieline(p)
+ with self._lock:
+ self._memoize_pipeline(p)
- # Create a new node for the user pipeline if it doesn't exist already.
- user_pipeline = self.get_user_pipeline(p)
- if not user_pipeline:
- user_pipeline = p
- self._user_pipelines[p] = []
+ # Create a new node for the user pipeline if it doesn't exist already.
+ user_pipeline = self.get_user_pipeline(p)
+ if not user_pipeline:
+ user_pipeline = p
+ self._user_pipelines[p] = []
- return user_pipeline
+ return user_pipeline
- def _memoize_pipieline(self, p: beam.Pipeline) -> None:
+ def _memoize_pipeline(self, p: beam.Pipeline) -> None:
"""Memoizes the pid of the pipeline to the pipeline object."""
pid = self._key(p)
- if pid not in self._pid_to_pipelines:
- self._pid_to_pipelines[pid] = p
+ with self._lock:
+ if pid not in self._pid_to_pipelines:
+ self._pid_to_pipelines[pid] = p
def add_derived_pipeline(
self, maybe_user_pipeline: beam.Pipeline,
@@ -119,20 +135,21 @@ class UserPipelineTracker:
# Returns p.
ut.get_user_pipeline(derived2)
"""
- self._memoize_pipieline(maybe_user_pipeline)
- self._memoize_pipieline(derived_pipeline)
+ with self._lock:
+ self._memoize_pipeline(maybe_user_pipeline)
+ self._memoize_pipeline(derived_pipeline)
- # Cannot add a derived pipeline twice.
- assert derived_pipeline not in self._derived_pipelines
+ # Cannot add a derived pipeline twice.
+ assert derived_pipeline not in self._derived_pipelines
- # Get the "true" user pipeline. This allows for the user to derive a
- # pipeline from another derived pipeline, use both as arguments, and this
- # method will still get the correct user pipeline.
- user = self.add_user_pipeline(maybe_user_pipeline)
+ # Get the "true" user pipeline. This allows for the user to derive a
+ # pipeline from another derived pipeline, use both as arguments, and this
+ # method will still get the correct user pipeline.
+ user = self.add_user_pipeline(maybe_user_pipeline)
- # Map the derived pipeline to the user pipeline.
- self._derived_pipelines[derived_pipeline] = user
- self._user_pipelines[user].append(derived_pipeline)
+ # Map the derived pipeline to the user pipeline.
+ self._derived_pipelines[derived_pipeline] = user
+ self._user_pipelines[user].append(derived_pipeline)
def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]:
"""Returns the user pipeline of the given pipeline.
@@ -142,14 +159,14 @@ class UserPipelineTracker:
returns the same pipeline. If the given pipeline is a derived pipeline then
this returns the user pipeline.
"""
+ with self._lock:
+ # If `p` is a user pipeline then return it.
+ if p in self._user_pipelines:
+ return p
- # If `p` is a user pipeline then return it.
- if p in self._user_pipelines:
- return p
-
- # If `p` exists then return its user pipeline.
- if p in self._derived_pipelines:
- return self._derived_pipelines[p]
+ # If `p` exists then return its user pipeline.
+ if p in self._derived_pipelines:
+ return self._derived_pipelines[p]
- # Otherwise, `p` is not in this tracker.
- return None
+ # Otherwise, `p` is not in this tracker.
+ return None
diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
index f7025b8b75b..6fb8e4dbad9 100644
--- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
+++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py
@@ -15,7 +15,9 @@
# limitations under the License.
#
+import threading
import unittest
+from unittest.mock import patch
import apache_beam as beam
from apache_beam.runners.interactive.user_pipeline_tracker import
UserPipelineTracker
@@ -202,6 +204,52 @@ class UserPipelineTrackerTest(unittest.TestCase):
self.assertIs(user2, ut.get_user_pipeline(derived21))
self.assertIs(user2, ut.get_user_pipeline(derived22))
+ def test_clear_race_condition(self):
+ ut = UserPipelineTracker()
+ # Add a pipeline so clear() has at least one element to iterate over.
+ p1 = beam.Pipeline()
+ derived1 = beam.Pipeline()
+ ut.add_derived_pipeline(p1, derived1)
+
+ # Set by the mock when clear() enters its loop. Signals the background
+ # worker to mutate.
+ in_loop_event = threading.Event()
+ # Set by the worker when mutation is complete. Signals mock that it can
+ # safely resume clear().
+ mutate_done_event = threading.Event()
+
+ def mock_rmtree(path, ignore_errors=False):
+ # Signal the worker that clear() is iterating.
+ in_loop_event.set()
+ # Pause here to give the worker thread time to perform the mutation.
+ mutate_done_event.wait(timeout=5)
+
+ def worker():
+ # Wait for clear() to start iterating.
+ if in_loop_event.wait(timeout=5):
+ # Concurrently mutate the tracker dictionary.
+ p2 = beam.Pipeline()
+ derived2 = beam.Pipeline()
+ try:
+ ut.add_derived_pipeline(p2, derived2)
+ finally:
+ # Resume the main thread.
+ mutate_done_event.set()
+
+ thread = threading.Thread(target=worker)
+ thread.start()
+
+ try:
+ # Intercept shutil.rmtree inside clear() to orchestrate the concurrent
+ # mutation.
+ with patch('shutil.rmtree', side_effect=mock_rmtree):
+ ut.clear()
+ finally:
+ # Avoid hanging tests if events are missed.
+ in_loop_event.set()
+ mutate_done_event.set()
+ thread.join()
+
if __name__ == '__main__':
unittest.main()