This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 44a17cbe429 Allow multi_process_shared objects to be called (#26202)
44a17cbe429 is described below

commit 44a17cbe429699160e689bd88935c32c6c2de6b2
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Mon Apr 24 13:35:33 2023 -0400

    Allow multi_process_shared objects to be called (#26202)
    
    * Allow multi_process_shared objects to be called
    
    * Allow multi_process_shared objects to be called (fixed, test passing)
    
    * formatting
    
    * Update sdks/python/apache_beam/utils/multi_process_shared.py
    
    Co-authored-by: Anand Inguva <34158215+ananding...@users.noreply.github.com>
    
    * Type hint
    
    * Type hint
    
    ---------
    
    Co-authored-by: Anand Inguva <34158215+ananding...@users.noreply.github.com>
---
 .../apache_beam/utils/multi_process_shared.py      | 31 ++++++++++++++++++++--
 .../apache_beam/utils/multi_process_shared_test.py | 26 ++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index cc0c01dd428..07814e82977 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -50,6 +50,12 @@ class _SingletonProxy:
     self._SingletonProxy_entry = entry
     self._SingletonProxy_valid = True
 
+  # Used to make the shared object callable (see _AutoProxyWrapper below)
+  def singletonProxy_call__(self, *args, **kwargs):
+    if not self._SingletonProxy_valid:
+      raise RuntimeError('Entry was released.')
+    return self._SingletonProxy_entry.obj.__call__(*args, **kwargs)
+
   def _SingletonProxy_release(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
@@ -61,7 +67,9 @@ class _SingletonProxy:
 
   def __dir__(self):
     # Needed for multiprocessing.managers's proxying.
-    return self._SingletonProxy_entry.obj.__dir__()
+    dir = self._SingletonProxy_entry.obj.__dir__()
+    dir.append('singletonProxy_call__')
+    return dir
 
 
 class _SingletonEntry:
@@ -127,6 +135,24 @@ _SingletonRegistrar.register(
     callable=_process_level_singleton_manager.release_singleton)
 
 
+# By default, objects registered with BaseManager.register will have only
+# public methods available (excluding __call__). If you know the functions
+# you would like to expose, you can do so at register time with the `exposed`
+# attribute. Since we don't, we will add a wrapper around the returned AutoProxy
+# object to handle __call__ function calls and turn them into
+# singletonProxy_call__ calls (which is a wrapper around the underlying
+# object's __call__ function)
+class _AutoProxyWrapper:
+  def __init__(self, proxyObject: multiprocessing.managers.BaseProxy):
+    self._proxyObject = proxyObject
+
+  def __call__(self, *args, **kwargs):
+    return self._proxyObject.singletonProxy_call__(*args, **kwargs)
+
+  def __getattr__(self, name):
+    return getattr(self._proxyObject, name)
+
+
 class MultiProcessShared(Generic[T]):
   """MultiProcessShared is used to share a single object across processes.
 
@@ -223,7 +249,8 @@ class MultiProcessShared(Generic[T]):
     # inputs)
     # Caveat: They must always agree, as they will be ignored if the object
     # is already constructed.
-    return self._get_manager().acquire_singleton(self._tag)
+    singleton = self._get_manager().acquire_singleton(self._tag)
+    return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
     self._manager.release_singleton(self._tag, obj)
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index d0cc2b16a80..de2702df59c 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -23,6 +23,23 @@ import unittest
 from apache_beam.utils import multi_process_shared
 
 
+class CallableCounter(object):
+  def __init__(self, start=0):
+    self.running = start
+    self.lock = threading.Lock()
+
+  def __call__(self):
+    return self.running
+
+  def increment(self, value=1):
+    with self.lock:
+      self.running += value
+      return self.running
+
+  def error(self, msg):
+    raise RuntimeError(msg)
+
+
 class Counter(object):
   def __init__(self, start=0):
     self.running = start
@@ -45,6 +62,8 @@ class MultiProcessSharedTest(unittest.TestCase):
   def setUpClass(cls):
     cls.shared = multi_process_shared.MultiProcessShared(
         Counter, always_proxy=True).acquire()
+    cls.sharedCallable = multi_process_shared.MultiProcessShared(
+        CallableCounter, always_proxy=True).acquire()
 
   def test_call(self):
     self.assertEqual(self.shared.get(), 0)
@@ -53,6 +72,13 @@ class MultiProcessSharedTest(unittest.TestCase):
     self.assertEqual(self.shared.increment(value=10), 21)
     self.assertEqual(self.shared.get(), 21)
 
+  def test_call_callable(self):
+    self.assertEqual(self.sharedCallable(), 0)
+    self.assertEqual(self.sharedCallable.increment(), 1)
+    self.assertEqual(self.sharedCallable.increment(10), 11)
+    self.assertEqual(self.sharedCallable.increment(value=10), 21)
+    self.assertEqual(self.sharedCallable(), 21)
+
   def test_error(self):
     with self.assertRaisesRegex(Exception, 'something bad'):
       self.shared.error('something bad')

Reply via email to