On 05/10/2012 02:20 PM, Petr Viktorin wrote:
While investigating ticket 2674, I found several problems with our
implementation of the CallbackInterface — it required complicated
calling code, and would subtly break if command classes were
instantiated in ways other than they currently are.

Here's my fix. See the commit message for details.


Rebased onto current master.


--
Petr³
From e6e4c0b64cd5c7fc2c80cbfe1c4fa9979b51bc59 Mon Sep 17 00:00:00 2001
From: Petr Viktorin <pvikt...@redhat.com>
Date: Wed, 25 Apr 2012 10:31:10 -0400
Subject: [PATCH] Rework the CallbackInterface

Fix several problems with the callback interface:
- Automatically registered callbacks (i.e. methods named
    exc_callback, pre_callback, etc.) were registered again on every
    instantiation.
    Fix: Do not register callbacks in __init__; instead return the
    method when asked for it.
- The calling code had to distinguish between bound methods and
    plain functions by checking the 'im_self' attribute.
    Fix: Always return the "default" callback as an unbound method.
    Registered callbacks now always take the extra `self` argument,
    whether they happen to be bound methods or not.
    Calling code now always needs to pass the `self` argument.
- Did not work well with inheritance: because Python looks up
    missing attributes in superclasses, callbacks could get attached
    to a superclass if it was instantiated early enough. *
    Fix: Instead of attribute lookup, use a dictionary with class keys.
- The interface included the callback types, which are LDAP-specific.
    Fix: Create generic register_callback and get_callbacks methods;
    move LDAP-specific code to BaseLDAPCommand.

Update code that calls the callbacks.
Add tests.
Remove lint exceptions for CallbackInterface.

* https://fedorahosted.org/freeipa/ticket/2674
---
 ipalib/cli.py                             |    9 +-
 ipalib/plugins/baseldap.py                |  354 +++++++++++------------------
 make-lint                                 |    2 -
 tests/test_xmlrpc/test_baseldap_plugin.py |   95 +++++++-
 4 files changed, 239 insertions(+), 221 deletions(-)

diff --git a/ipalib/cli.py b/ipalib/cli.py
index 8279345a909ac2e3b764d1f92cfc45ed679b9204..d53e6cd403947ad8cdce1d20cb692db10a0c3dd5 100644
--- a/ipalib/cli.py
+++ b/ipalib/cli.py
@@ -1195,8 +1195,13 @@ def prompt_interactively(self, cmd, kw):
                     param.label, param.confirm
                 )
 
-        for callback in getattr(cmd, 'INTERACTIVE_PROMPT_CALLBACKS', []):
-            callback(kw)
+        try:
+            callbacks = cmd.get_callbacks('interactive_prompt')
+        except AttributeError:
+            pass
+        else:
+            for callback in callbacks:
+                callback(cmd, kw)
 
     def load_files(self, cmd, kw):
         """
diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py
index 2851f0f270d9e2bdba4780cc7bf308a76e180fd2..dd5c1411b2232522061955aa8aa8b7f9ff7f9c16 100644
--- a/ipalib/plugins/baseldap.py
+++ b/ipalib/plugins/baseldap.py
@@ -690,93 +690,57 @@ def _check_limit_object_class(attributes, attrs, allow_only):
     if len(limitattrs) > 0 and allow_only:
         raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0]))
 
+
 class CallbackInterface(Method):
-    """
-    Callback registration interface
-    """
-    def __init__(self):
-        #pylint: disable=E1003
-        if not hasattr(self.__class__, 'PRE_CALLBACKS'):
-            self.__class__.PRE_CALLBACKS = []
-        if not hasattr(self.__class__, 'POST_CALLBACKS'):
-            self.__class__.POST_CALLBACKS = []
-        if not hasattr(self.__class__, 'EXC_CALLBACKS'):
-            self.__class__.EXC_CALLBACKS = []
-        if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'):
-            self.__class__.INTERACTIVE_PROMPT_CALLBACKS = []
-        if hasattr(self, 'pre_callback'):
-            self.register_pre_callback(self.pre_callback, True)
-        if hasattr(self, 'post_callback'):
-            self.register_post_callback(self.post_callback, True)
-        if hasattr(self, 'exc_callback'):
-            self.register_exc_callback(self.exc_callback, True)
-        if hasattr(self, 'interactive_prompt_callback'):
-            self.register_interactive_prompt_callback(
-                    self.interactive_prompt_callback, True) #pylint: disable=E1101
-        super(Method, self).__init__()
-
-    @classmethod
-    def register_pre_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'PRE_CALLBACKS'):
-            klass.PRE_CALLBACKS = []
-        if first:
-            klass.PRE_CALLBACKS.insert(0, callback)
-        else:
-            klass.PRE_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_post_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'POST_CALLBACKS'):
-            klass.POST_CALLBACKS = []
-        if first:
-            klass.POST_CALLBACKS.insert(0, callback)
-        else:
-            klass.POST_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_exc_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'EXC_CALLBACKS'):
-            klass.EXC_CALLBACKS = []
-        if first:
-            klass.EXC_CALLBACKS.insert(0, callback)
-        else:
-            klass.EXC_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_interactive_prompt_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'):
-            klass.INTERACTIVE_PROMPT_CALLBACKS = []
-        if first:
-            klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback)
-        else:
-            klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
-
-    def _exc_wrapper(self, keys, options, call_func):
-        """Function wrapper that automatically calls exception callbacks"""
-        def wrapped(*call_args, **call_kwargs):
-            # call call_func first
-            func = call_func
-            callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
-            while True:
-                try:
-                    return func(*call_args, **call_kwargs)
-                except errors.ExecutionError, e:
-                    if not callbacks:
-                        raise
-                    # call exc_callback in the next loop
-                    callback = callbacks.pop(0)
-                    if hasattr(callback, 'im_self'):
-                        def exc_func(*args, **kwargs):
-                            return callback(keys, options, e, call_func, *args, **kwargs)
-                    else:
-                        def exc_func(*args, **kwargs):
-                            return callback(self, keys, options, e, call_func, *args, **kwargs)
-                    func = exc_func
-        return wrapped
+    """Callback registration interface
+
+    This class's subclasses allow different types of callbacks to be added and
+    removed to them.
+    Registering a callback is done either by ``register_callback``, or by
+    defining a ``<type>_callback`` method.
+
+    Subclasses should define the `_callback_registry` attribute as a dictionary
+    mapping allowed callback types to (initially) empty dictionaries.
+    """
+
+    _callback_registry = dict()
+
+    @classmethod
+    def get_callbacks(cls, callback_type):
+        """Yield callbacks of the given type"""
+        # Use one shared callback registry, keyed on class, to avoid problems
+        # with missing attributes being looked up in superclasses
+        callbacks = cls._callback_registry[callback_type].get(cls, [None])
+        for callback in callbacks:
+            if callback is None:
+                try:
+                    yield getattr(cls, '%s_callback' % callback_type)
+                except AttributeError:
+                    pass
+            else:
+                yield callback
+
+    @classmethod
+    def register_callback(cls, callback_type, callback, first=False):
+        """Register a callback
+
+        :param callback_type: The callback type (e.g. 'pre', 'post')
+        :param callback: The callable added
+        :param first: If true, the new callback will be added before all
+            existing callbacks; otherwise it's added after them
+
+        Note that callbacks registered this way will be attached to this class
+        only, not to its subclasses.
+        """
+        assert callable(callback)
+        try:
+            callbacks = cls._callback_registry[callback_type][cls]
+        except KeyError:
+            callbacks = cls._callback_registry[callback_type][cls] = [None]
+        if first:
+            callbacks.insert(0, callback)
+        else:
+            callbacks.append(callback)
 
 
 class BaseLDAPCommand(CallbackInterface, Command):
@@ -802,6 +766,8 @@ class BaseLDAPCommand(CallbackInterface, Command):
                          exclude='webui',
                         )
 
+    _callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={})
+
     def _convert_2_dict(self, attrs):
         """
         Convert a string in the form of name/value pairs into a dictionary.
@@ -953,6 +919,45 @@ def process_attr_options(self, entry_attrs, dn, keys, options):
                 elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1:
                     entry_attrs[attr] = entry_attrs[attr][0]
 
+    @classmethod
+    def register_pre_callback(cls, callback, first=False):
+        """Shortcut for register_callback('pre', ...)"""
+        cls.register_callback('pre', callback, first)
+
+    @classmethod
+    def register_post_callback(cls, callback, first=False):
+        """Shortcut for register_callback('post', ...)"""
+        cls.register_callback('post', callback, first)
+
+    @classmethod
+    def register_exc_callback(cls, callback, first=False):
+        """Shortcut for register_callback('exc', ...)"""
+        cls.register_callback('exc', callback, first)
+
+    @classmethod
+    def register_interactive_prompt_callback(cls, callback, first=False):
+        """Shortcut for register_callback('interactive_prompt', ...)"""
+        cls.register_callback('interactive_prompt', callback, first)
+
+    def _exc_wrapper(self, keys, options, call_func):
+        """Function wrapper that automatically calls exception callbacks"""
+        def wrapped(*call_args, **call_kwargs):
+            # call call_func first
+            func = call_func
+            callbacks = list(self.get_callbacks('exc'))
+            while True:
+                try:
+                    return func(*call_args, **call_kwargs)
+                except errors.ExecutionError, e:
+                    if not callbacks:
+                        raise
+                    # call exc_callback in the next loop
+                    callback = callbacks.pop(0)
+                    def exc_func(*args, **kwargs):
+                        return callback(
+                            self, keys, options, e, call_func, *args, **kwargs)
+                    func = exc_func
+        return wrapped
 
 class LDAPCreate(BaseLDAPCommand, crud.Create):
     """
@@ -1004,15 +1009,9 @@ def execute(self, *keys, **options):
                 set(self.obj.default_attributes + entry_attrs.keys())
             )
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(
-                    ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
-            else:
-                dn = callback(
-                    self, ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(
+                self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
 
         _check_single_value_attrs(self.params, entry_attrs)
         ldap.get_schema()
@@ -1056,11 +1055,8 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         entry_attrs['dn'] = dn
 
@@ -1163,11 +1159,8 @@ def execute(self, *keys, **options):
         else:
             attrs_list = list(self.obj.default_attributes)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, attrs_list, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, attrs_list, *keys, **options)
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, attrs_list, *keys, **options)
 
         try:
             (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
@@ -1179,11 +1172,8 @@ def execute(self, *keys, **options):
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
         entry_attrs['dn'] = dn
@@ -1258,15 +1248,9 @@ def execute(self, *keys, **options):
         _check_single_value_attrs(self.params, entry_attrs)
         _check_empty_attrs(self.obj.params, entry_attrs)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(
-                    ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
-            else:
-                dn = callback(
-                    self, ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(
+                self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
 
         ldap.get_schema()
         _check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.limit_object_classes), entry_attrs.keys(), allow_only=True)
@@ -1313,11 +1297,8 @@ def execute(self, *keys, **options):
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
         if self.obj.primary_key and keys[-1] is not None:
@@ -1352,11 +1333,8 @@ def delete_entry(pkey):
             nkeys = keys[:-1] + (pkey, )
             dn = self.obj.get_dn(*nkeys, **options)
 
-            for callback in self.PRE_CALLBACKS:
-                if hasattr(callback, 'im_self'):
-                    dn = callback(ldap, dn, *nkeys, **options)
-                else:
-                    dn = callback(self, ldap, dn, *nkeys, **options)
+            for callback in self.get_callbacks('pre'):
+                dn = callback(self, ldap, dn, *nkeys, **options)
 
             def delete_subtree(base_dn):
                 truncated = True
@@ -1377,11 +1355,8 @@ def delete_subtree(base_dn):
 
             delete_subtree(dn)
 
-            for callback in self.POST_CALLBACKS:
-                if hasattr(callback, 'im_self'):
-                    result = callback(ldap, dn, *nkeys, **options)
-                else:
-                    result = callback(self, ldap, dn, *nkeys, **options)
+            for callback in self.get_callbacks('post'):
+                result = callback(self, ldap, dn, *nkeys, **options)
 
             return result
 
@@ -1493,13 +1468,8 @@ def execute(self, *keys, **options):
 
         dn = self.obj.get_dn(*keys, **options)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, member_dns, failed, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, member_dns, failed, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
 
         completed = 0
         for (attr, objs) in member_dns.iteritems():
@@ -1532,16 +1502,10 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@@ -1592,13 +1556,8 @@ def execute(self, *keys, **options):
 
         dn = self.obj.get_dn(*keys, **options)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, member_dns, failed, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, member_dns, failed, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
 
         completed = 0
         for (attr, objs) in member_dns.iteritems():
@@ -1634,16 +1593,10 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
 
@@ -1828,15 +1781,9 @@ def execute(self, *args, **options):
         )
 
         scope = ldap.SCOPE_ONELEVEL
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                    (filter, base_dn, scope) = callback(
-                        ldap, filter, attrs_list, base_dn, scope, *args, **options
-                    )
-            else:
-                (filter, base_dn, scope) = callback(
-                    self, ldap, filter, attrs_list, base_dn, scope, *args, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            (filter, base_dn, scope) = callback(
+                self, ldap, filter, attrs_list, base_dn, scope, *args, **options)
 
         try:
             (entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
@@ -1847,11 +1794,8 @@ def execute(self, *args, **options):
         except errors.NotFound:
             (entries, truncated) = ([], False)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                callback(ldap, entries, truncated, *args, **options)
-            else:
-                callback(self, ldap, entries, truncated, *args, **options)
+        for callback in self.get_callbacks('post'):
+            callback(self, ldap, entries, truncated, *args, **options)
 
         if self.sort_result_entries:
             if self.obj.primary_key:
@@ -1954,13 +1898,8 @@ def execute(self, *keys, **options):
         result = self.api.Command[self.show_command](keys[-1])['result']
         dn = result['dn']
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, *keys, **options)
 
         if options.get('all', False):
             attrs_list = ['*'] + self.obj.default_attributes
@@ -1995,16 +1934,10 @@ def execute(self, *keys, **options):
         except Exception, e:
             raise errors.ReverseMemberError(verb=_('added'), exc=str(e))
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         return dict(
@@ -2061,13 +1994,8 @@ def execute(self, *keys, **options):
         result = self.api.Command[self.show_command](keys[-1])['result']
         dn = result['dn']
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, *keys, **options)
 
         if options.get('all', False):
             attrs_list = ['*'] + self.obj.default_attributes
@@ -2102,16 +2030,10 @@ def execute(self, *keys, **options):
         except Exception, e:
             raise errors.ReverseMemberError(verb=_('removed'), exc=str(e))
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         return dict(
diff --git a/make-lint b/make-lint
index 7ecd59d7e8c5a644f812d4b8987866e7d06236b5..f619260434e33886175f5b7d5a1d008466f92a54 100755
--- a/make-lint
+++ b/make-lint
@@ -51,8 +51,6 @@ class IPATypeChecker(TypeChecker):
         'ipalib.plugable.Plugin': ['Command', 'Object', 'Method', 'Property',
             'Backend', 'env', 'debug', 'info', 'warning', 'error', 'critical',
             'exception', 'context', 'log'],
-        'ipalib.plugins.baseldap.CallbackInterface': ['pre_callback',
-            'post_callback', 'exc_callback'],
         'ipalib.plugins.misc.env': ['env'],
         'ipalib.parameters.Param': ['cli_name', 'cli_short_name', 'label',
             'doc', 'required', 'multivalue', 'primary_key', 'normalizer',
diff --git a/tests/test_xmlrpc/test_baseldap_plugin.py b/tests/test_xmlrpc/test_baseldap_plugin.py
index 0800a5d5249b937176770965b6d41cf2b641d4c7..6a8501f76529c811e48fa0b36403da052af69b57 100644
--- a/tests/test_xmlrpc/test_baseldap_plugin.py
+++ b/tests/test_xmlrpc/test_baseldap_plugin.py
@@ -24,11 +24,12 @@
 from ipalib import errors
 from ipalib.plugins import baseldap
 
+
 def test_exc_wrapper():
     """Test the CallbackInterface._exc_wrapper helper method"""
     handled_exceptions = []
 
-    class test_callback(baseldap.CallbackInterface):
+    class test_callback(baseldap.BaseLDAPCommand):
         """Fake IPA method"""
         def test_fail(self):
             self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
@@ -64,3 +65,95 @@ def dont_handle(self, keys, options, e, call_func, *args, **kwargs):
 
     instance.test_fail()
     assert handled_exceptions == [None, errors.ExecutionError]
+
+
+def test_callback_registration():
+    class callbacktest_base(baseldap.CallbackInterface):
+        _callback_registry = dict(test={})
+
+        def test_callback(self, param):
+            messages.append(('Base test_callback', param))
+
+    def registered_callback(self, param):
+        messages.append(('Base registered callback', param))
+    callbacktest_base.register_callback('test', registered_callback)
+
+    class SomeClass(object):
+        def registered_callback(self, command, param):
+            messages.append(('Registered callback from another class', param))
+    callbacktest_base.register_callback('test', SomeClass().registered_callback)
+
+    class callbacktest_subclass(callbacktest_base):
+        pass
+
+    def subclass_callback(self, param):
+        messages.append(('Subclass registered callback', param))
+    callbacktest_subclass.register_callback('test', subclass_callback)
+
+
+    messages = []
+    instance = callbacktest_base()
+    for callback in instance.get_callbacks('test'):
+        callback(instance, 42)
+    assert messages == [
+            ('Base test_callback', 42),
+            ('Base registered callback', 42),
+            ('Registered callback from another class', 42)]
+
+    messages = []
+    instance = callbacktest_subclass()
+    for callback in instance.get_callbacks('test'):
+        callback(instance, 42)
+    assert messages == [
+            ('Base test_callback', 42),
+            ('Subclass registered callback', 42)]
+
+
+def test_exc_callback_registration():
+    messages = []
+    class callbacktest_base(baseldap.BaseLDAPCommand):
+        """A method superclass with an exception callback"""
+        def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+            """Let the world know we saw the error, but don't handle it"""
+            messages.append('Base exc_callback')
+            raise exc
+
+        def test_fail(self):
+            """Raise a handled exception"""
+            try:
+                self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
+            except Exception:
+                pass
+
+        def fail(self, *args, **kwargs):
+            """Raise an error"""
+            raise errors.ExecutionError('failure')
+
+    base_instance = callbacktest_base()
+
+    class callbacktest_subclass(callbacktest_base):
+        pass
+
+    @callbacktest_subclass.register_exc_callback
+    def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+        """Subclass's private exception callback"""
+        messages.append('Subclass registered callback')
+        raise exc
+
+    subclass_instance = callbacktest_subclass()
+
+    # Make sure exception in base class is only handled by the base class
+    base_instance.test_fail()
+    assert messages == ['Base exc_callback']
+
+
+    @callbacktest_base.register_exc_callback
+    def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+        """Callback on super class; doesn't affect the subclass"""
+        messages.append('Superclass registered callback')
+        raise exc
+
+    # Make sure exception in subclass is only handled by both
+    messages = []
+    subclass_instance.test_fail()
+    assert messages == ['Base exc_callback', 'Subclass registered callback']
--
1.7.10.1

_______________________________________________
Freeipa-devel mailing list
Freeipa-devel@redhat.com
https://www.redhat.com/mailman/listinfo/freeipa-devel

Reply via email to