This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 44a17cbe429 Allow multi_process_shared objects to be called (#26202) 44a17cbe429 is described below commit 44a17cbe429699160e689bd88935c32c6c2de6b2 Author: Danny McCormick <dannymccorm...@google.com> AuthorDate: Mon Apr 24 13:35:33 2023 -0400 Allow multi_process_shared objects to be called (#26202) * Allow multi_process_shared objects to be called * Allow multi_process_shared objects to be called (fixed, test passing) * formatting * Update sdks/python/apache_beam/utils/multi_process_shared.py Co-authored-by: Anand Inguva <34158215+ananding...@users.noreply.github.com> * Type hint * Type hint --------- Co-authored-by: Anand Inguva <34158215+ananding...@users.noreply.github.com> --- .../apache_beam/utils/multi_process_shared.py | 31 ++++++++++++++++++++-- .../apache_beam/utils/multi_process_shared_test.py | 26 ++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index cc0c01dd428..07814e82977 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -50,6 +50,12 @@ class _SingletonProxy: self._SingletonProxy_entry = entry self._SingletonProxy_valid = True + # Used to make the shared object callable (see _AutoProxyWrapper below) + def singletonProxy_call__(self, *args, **kwargs): + if not self._SingletonProxy_valid: + raise RuntimeError('Entry was released.') + return self._SingletonProxy_entry.obj.__call__(*args, **kwargs) + def _SingletonProxy_release(self): assert self._SingletonProxy_valid self._SingletonProxy_valid = False @@ -61,7 +67,9 @@ class _SingletonProxy: def __dir__(self): # Needed for multiprocessing.managers's proxying. - return self._SingletonProxy_entry.obj.__dir__() + dir = self._SingletonProxy_entry.obj.__dir__() + dir.append('singletonProxy_call__') + return dir class _SingletonEntry: @@ -127,6 +135,24 @@ _SingletonRegistrar.register( callable=_process_level_singleton_manager.release_singleton) +# By default, objects registered with BaseManager.register will have only +# public methods available (excluding __call__). If you know the functions +# you would like to expose, you can do so at register time with the `exposed` +# attribute. Since we don't, we will add a wrapper around the returned AutoProxy +# object to handle __call__ function calls and turn them into +# singletonProxy_call__ calls (which is a wrapper around the underlying +# object's __call__ function) +class _AutoProxyWrapper: + def __init__(self, proxyObject: multiprocessing.managers.BaseProxy): + self._proxyObject = proxyObject + + def __call__(self, *args, **kwargs): + return self._proxyObject.singletonProxy_call__(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._proxyObject, name) + + class MultiProcessShared(Generic[T]): """MultiProcessShared is used to share a single object across processes. @@ -223,7 +249,8 @@ class MultiProcessShared(Generic[T]): # inputs) # Caveat: They must always agree, as they will be ignored if the object # is already constructed. - return self._get_manager().acquire_singleton(self._tag) + singleton = self._get_manager().acquire_singleton(self._tag) + return _AutoProxyWrapper(singleton) def release(self, obj): self._manager.release_singleton(self._tag, obj) diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index d0cc2b16a80..de2702df59c 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -23,6 +23,23 @@ import unittest from apache_beam.utils import multi_process_shared +class CallableCounter(object): + def __init__(self, start=0): + self.running = start + self.lock = threading.Lock() + + def __call__(self): + return self.running + + def increment(self, value=1): + with self.lock: + self.running += value + return self.running + + def error(self, msg): + raise RuntimeError(msg) + + class Counter(object): def __init__(self, start=0): self.running = start @@ -45,6 +62,8 @@ class MultiProcessSharedTest(unittest.TestCase): def setUpClass(cls): cls.shared = multi_process_shared.MultiProcessShared( Counter, always_proxy=True).acquire() + cls.sharedCallable = multi_process_shared.MultiProcessShared( + CallableCounter, always_proxy=True).acquire() def test_call(self): self.assertEqual(self.shared.get(), 0) @@ -53,6 +72,13 @@ class MultiProcessSharedTest(unittest.TestCase): self.assertEqual(self.shared.increment(value=10), 21) self.assertEqual(self.shared.get(), 21) + def test_call_callable(self): + self.assertEqual(self.sharedCallable(), 0) + self.assertEqual(self.sharedCallable.increment(), 1) + self.assertEqual(self.sharedCallable.increment(10), 11) + self.assertEqual(self.sharedCallable.increment(value=10), 21) + self.assertEqual(self.sharedCallable(), 21) + def test_error(self): with self.assertRaisesRegex(Exception, 'something bad'): self.shared.error('something bad')