While investigating ticket 2674, I found several problems with our implementation of the CallbackInterface — it required complicated calling code, and would subtly break if command classes were instantiated in different ways than they are currently.

Here's my fix. See commit message for details.

--
Petr³
From 20036ab7b06f00380670243c49b3959983f39320 Mon Sep 17 00:00:00 2001
From: Petr Viktorin <pvikt...@redhat.com>
Date: Wed, 25 Apr 2012 10:31:10 -0400
Subject: [PATCH] Rework the CallbackInterface

Fix several problems with the callback interface:
- Automatically registered callbacks (i.e. methods named
    exc_callback, pre_callback etc) were registered on every
    instantiation.
    Fix: Do not register callbacks in __init__; instead return the
    method when asked for it.
- The calling code had to distinguish between bound methods and
    plain functions by checking the 'im_self' attribute.
    Fix: Always return the "default" callback as an unbound method.
    Registered callbacks now always take the extra `self` argument,
    whether they happen to be bound methods or not.
    Calling code now always needs to pass the `self` argument.
- Did not work well with inheritance: due to the fact that Python
    looks up missing attributes in superclasses, callbacks could
    get attached to a superclass if it was instantiated early enough. *
    Fix: Instead of attribute lookup, use a dictionary with class keys.
- The interface included the callback types, which are LDAP-specific.
    Fix: Create generic register_callback and get_callbacks methods,
    move LDAP-specific code to BaseLDAPCommand

Update code that calls the callbacks.
Add tests.
Remove lint exceptions for CallbackInterface.

* https://fedorahosted.org/freeipa/ticket/2674
---
 ipalib/cli.py                             |    9 +-
 ipalib/plugins/baseldap.py                |  354 +++++++++++------------------
 make-lint                                 |    2 -
 tests/test_xmlrpc/test_baseldap_plugin.py |   95 +++++++-
 4 files changed, 239 insertions(+), 221 deletions(-)

diff --git a/ipalib/cli.py b/ipalib/cli.py
index 8279345a909ac2e3b764d1f92cfc45ed679b9204..d53e6cd403947ad8cdce1d20cb692db10a0c3dd5 100644
--- a/ipalib/cli.py
+++ b/ipalib/cli.py
@@ -1195,8 +1195,13 @@ def prompt_interactively(self, cmd, kw):
                     param.label, param.confirm
                 )
 
-        for callback in getattr(cmd, 'INTERACTIVE_PROMPT_CALLBACKS', []):
-            callback(kw)
+        try:
+            callbacks = cmd.get_callbacks('interactive_prompt')
+        except AttributeError:
+            pass
+        else:
+            for callback in callbacks:
+                callback(cmd, kw)
 
     def load_files(self, cmd, kw):
         """
diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py
index 85a81723175f38f10c711530971f173a54f1150a..3e1c51357eadb7b6a3f3d7fa900e97ee45267c1a 100644
--- a/ipalib/plugins/baseldap.py
+++ b/ipalib/plugins/baseldap.py
@@ -679,93 +679,57 @@ def _check_limit_object_class(attributes, attrs, allow_only):
     if len(limitattrs) > 0 and allow_only:
         raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0]))
 
+
 class CallbackInterface(Method):
-    """
-    Callback registration interface
-    """
-    def __init__(self):
-        #pylint: disable=E1003
-        if not hasattr(self.__class__, 'PRE_CALLBACKS'):
-            self.__class__.PRE_CALLBACKS = []
-        if not hasattr(self.__class__, 'POST_CALLBACKS'):
-            self.__class__.POST_CALLBACKS = []
-        if not hasattr(self.__class__, 'EXC_CALLBACKS'):
-            self.__class__.EXC_CALLBACKS = []
-        if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'):
-            self.__class__.INTERACTIVE_PROMPT_CALLBACKS = []
-        if hasattr(self, 'pre_callback'):
-            self.register_pre_callback(self.pre_callback, True)
-        if hasattr(self, 'post_callback'):
-            self.register_post_callback(self.post_callback, True)
-        if hasattr(self, 'exc_callback'):
-            self.register_exc_callback(self.exc_callback, True)
-        if hasattr(self, 'interactive_prompt_callback'):
-            self.register_interactive_prompt_callback(
-                    self.interactive_prompt_callback, True) #pylint: disable=E1101
-        super(Method, self).__init__()
-
-    @classmethod
-    def register_pre_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'PRE_CALLBACKS'):
-            klass.PRE_CALLBACKS = []
-        if first:
-            klass.PRE_CALLBACKS.insert(0, callback)
-        else:
-            klass.PRE_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_post_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'POST_CALLBACKS'):
-            klass.POST_CALLBACKS = []
-        if first:
-            klass.POST_CALLBACKS.insert(0, callback)
-        else:
-            klass.POST_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_exc_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'EXC_CALLBACKS'):
-            klass.EXC_CALLBACKS = []
-        if first:
-            klass.EXC_CALLBACKS.insert(0, callback)
-        else:
-            klass.EXC_CALLBACKS.append(callback)
-
-    @classmethod
-    def register_interactive_prompt_callback(klass, callback, first=False):
-        assert callable(callback)
-        if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'):
-            klass.INTERACTIVE_PROMPT_CALLBACKS = []
-        if first:
-            klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback)
-        else:
-            klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
-
-    def _exc_wrapper(self, keys, options, call_func):
-        """Function wrapper that automatically calls exception callbacks"""
-        def wrapped(*call_args, **call_kwargs):
-            # call call_func first
-            func = call_func
-            callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
-            while True:
-                try:
-                    return func(*call_args, **call_kwargs)
-                except errors.ExecutionError, e:
-                    if not callbacks:
-                        raise
-                    # call exc_callback in the next loop
-                    callback = callbacks.pop(0)
-                    if hasattr(callback, 'im_self'):
-                        def exc_func(*args, **kwargs):
-                            return callback(keys, options, e, call_func, *args, **kwargs)
-                    else:
-                        def exc_func(*args, **kwargs):
-                            return callback(self, keys, options, e, call_func, *args, **kwargs)
-                    func = exc_func
-        return wrapped
+    """Callback registration interface
+
+    This class's subclasses allow different types of callbacks to be added and
+    removed to them.
+    Registering a callback is done either by ``register_callback``, or by
+    defining a ``<type>_callback`` method.
+
+    Subclasses should define the `_callback_registry` attribute as a dictionary
+    mapping allowed callback types to (initially) empty dictionaries.
+    """
+
+    _callback_registry = dict()
+
+    @classmethod
+    def get_callbacks(cls, callback_type):
+        """Yield callbacks of the given type"""
+        # Use one shared callback registry, keyed on class, to avoid problems
+        # with missing attributes being looked up in superclasses
+        callbacks = cls._callback_registry[callback_type].get(cls, [None])
+        for callback in callbacks:
+            if callback is None:
+                try:
+                    yield getattr(cls, '%s_callback' % callback_type)
+                except AttributeError:
+                    pass
+            else:
+                yield callback
+
+    @classmethod
+    def register_callback(cls, callback_type, callback, first=False):
+        """Register a callback
+
+        :param callback_type: The callback type (e.g. 'pre', 'post')
+        :param callback: The callable added
+        :param first: If true, the new callback will be added before all
+            existing callbacks; otherwise it's added after them
+
+        Note that callbacks registered this way will be attached to this class
+        only, not to its subclasses.
+        """
+        assert callable(callback)
+        try:
+            callbacks = cls._callback_registry[callback_type][cls]
+        except KeyError:
+            callbacks = cls._callback_registry[callback_type][cls] = [None]
+        if first:
+            callbacks.insert(0, callback)
+        else:
+            callbacks.append(callback)
 
 
 class BaseLDAPCommand(CallbackInterface, Command):
@@ -791,6 +755,8 @@ class BaseLDAPCommand(CallbackInterface, Command):
                          exclude='webui',
                         )
 
+    _callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={})
+
     def _convert_2_dict(self, attrs):
         """
         Convert a string in the form of name/value pairs into a dictionary.
@@ -942,6 +908,45 @@ def process_attr_options(self, entry_attrs, dn, keys, options):
                 elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1:
                     entry_attrs[attr] = entry_attrs[attr][0]
 
+    @classmethod
+    def register_pre_callback(cls, callback, first=False):
+        """Shortcut for register_callback('pre', ...)"""
+        cls.register_callback('pre', callback, first)
+
+    @classmethod
+    def register_post_callback(cls, callback, first=False):
+        """Shortcut for register_callback('post', ...)"""
+        cls.register_callback('post', callback, first)
+
+    @classmethod
+    def register_exc_callback(cls, callback, first=False):
+        """Shortcut for register_callback('exc', ...)"""
+        cls.register_callback('exc', callback, first)
+
+    @classmethod
+    def register_interactive_prompt_callback(cls, callback, first=False):
+        """Shortcut for register_callback('interactive_prompt', ...)"""
+        cls.register_callback('interactive_prompt', callback, first)
+
+    def _exc_wrapper(self, keys, options, call_func):
+        """Function wrapper that automatically calls exception callbacks"""
+        def wrapped(*call_args, **call_kwargs):
+            # call call_func first
+            func = call_func
+            callbacks = list(self.get_callbacks('exc'))
+            while True:
+                try:
+                    return func(*call_args, **call_kwargs)
+                except errors.ExecutionError, e:
+                    if not callbacks:
+                        raise
+                    # call exc_callback in the next loop
+                    callback = callbacks.pop(0)
+                    def exc_func(*args, **kwargs):
+                        return callback(
+                            self, keys, options, e, call_func, *args, **kwargs)
+                    func = exc_func
+        return wrapped
 
 class LDAPCreate(BaseLDAPCommand, crud.Create):
     """
@@ -993,15 +998,9 @@ def execute(self, *keys, **options):
                 set(self.obj.default_attributes + entry_attrs.keys())
             )
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(
-                    ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
-            else:
-                dn = callback(
-                    self, ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(
+                self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
 
         _check_single_value_attrs(self.params, entry_attrs)
         ldap.get_schema()
@@ -1045,11 +1044,8 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         entry_attrs['dn'] = dn
 
@@ -1152,11 +1148,8 @@ def execute(self, *keys, **options):
         else:
             attrs_list = list(self.obj.default_attributes)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, attrs_list, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, attrs_list, *keys, **options)
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, attrs_list, *keys, **options)
 
         try:
             (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
@@ -1168,11 +1161,8 @@ def execute(self, *keys, **options):
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
         entry_attrs['dn'] = dn
@@ -1244,15 +1234,9 @@ def execute(self, *keys, **options):
                 set(self.obj.default_attributes + entry_attrs.keys())
             )
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(
-                    ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
-            else:
-                dn = callback(
-                    self, ldap, dn, entry_attrs, attrs_list, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(
+                self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
 
         _check_single_value_attrs(self.params, entry_attrs)
         _check_empty_attrs(self.obj.params, entry_attrs)
@@ -1301,11 +1285,8 @@ def execute(self, *keys, **options):
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, entry_attrs, *keys, **options)
-            else:
-                dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+        for callback in self.get_callbacks('post'):
+            dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
 
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
         if self.obj.primary_key and keys[-1] is not None:
@@ -1340,11 +1321,8 @@ def delete_entry(pkey):
             nkeys = keys[:-1] + (pkey, )
             dn = self.obj.get_dn(*nkeys, **options)
 
-            for callback in self.PRE_CALLBACKS:
-                if hasattr(callback, 'im_self'):
-                    dn = callback(ldap, dn, *nkeys, **options)
-                else:
-                    dn = callback(self, ldap, dn, *nkeys, **options)
+            for callback in self.get_callbacks('pre'):
+                dn = callback(self, ldap, dn, *nkeys, **options)
 
             def delete_subtree(base_dn):
                 truncated = True
@@ -1365,11 +1343,8 @@ def delete_subtree(base_dn):
 
             delete_subtree(dn)
 
-            for callback in self.POST_CALLBACKS:
-                if hasattr(callback, 'im_self'):
-                    result = callback(ldap, dn, *nkeys, **options)
-                else:
-                    result = callback(self, ldap, dn, *nkeys, **options)
+            for callback in self.get_callbacks('post'):
+                result = callback(self, ldap, dn, *nkeys, **options)
 
             return result
 
@@ -1481,13 +1456,8 @@ def execute(self, *keys, **options):
 
         dn = self.obj.get_dn(*keys, **options)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, member_dns, failed, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, member_dns, failed, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
 
         completed = 0
         for (attr, objs) in member_dns.iteritems():
@@ -1520,16 +1490,10 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@@ -1580,13 +1544,8 @@ def execute(self, *keys, **options):
 
         dn = self.obj.get_dn(*keys, **options)
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, member_dns, failed, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, member_dns, failed, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
 
         completed = 0
         for (attr, objs) in member_dns.iteritems():
@@ -1622,16 +1581,10 @@ def execute(self, *keys, **options):
         except errors.NotFound:
             self.obj.handle_not_found(*keys)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
 
@@ -1816,15 +1769,9 @@ def execute(self, *args, **options):
         )
 
         scope = ldap.SCOPE_ONELEVEL
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                    (filter, base_dn, scope) = callback(
-                        ldap, filter, attrs_list, base_dn, scope, *args, **options
-                    )
-            else:
-                (filter, base_dn, scope) = callback(
-                    self, ldap, filter, attrs_list, base_dn, scope, *args, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            (filter, base_dn, scope) = callback(
+                self, ldap, filter, attrs_list, base_dn, scope, *args, **options)
 
         try:
             (entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
@@ -1835,11 +1782,8 @@ def execute(self, *args, **options):
         except errors.NotFound:
             (entries, truncated) = ([], False)
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                callback(ldap, entries, truncated, *args, **options)
-            else:
-                callback(self, ldap, entries, truncated, *args, **options)
+        for callback in self.get_callbacks('post'):
+            callback(self, ldap, entries, truncated, *args, **options)
 
         if self.sort_result_entries:
             if self.obj.primary_key:
@@ -1942,13 +1886,8 @@ def execute(self, *keys, **options):
         result = self.api.Command[self.show_command](keys[-1])['result']
         dn = result['dn']
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, *keys, **options)
 
         if options.get('all', False):
             attrs_list = ['*'] + self.obj.default_attributes
@@ -1983,16 +1922,10 @@ def execute(self, *keys, **options):
         except Exception, e:
             raise errors.ReverseMemberError(verb=_('added'), exc=str(e))
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         return dict(
@@ -2049,13 +1982,8 @@ def execute(self, *keys, **options):
         result = self.api.Command[self.show_command](keys[-1])['result']
         dn = result['dn']
 
-        for callback in self.PRE_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                dn = callback(ldap, dn, *keys, **options)
-            else:
-                dn = callback(
-                    self, ldap, dn, *keys, **options
-                )
+        for callback in self.get_callbacks('pre'):
+            dn = callback(self, ldap, dn, *keys, **options)
 
         if options.get('all', False):
             attrs_list = ['*'] + self.obj.default_attributes
@@ -2090,16 +2018,10 @@ def execute(self, *keys, **options):
         except Exception, e:
             raise errors.ReverseMemberError(verb=_('removed'), exc=str(e))
 
-        for callback in self.POST_CALLBACKS:
-            if hasattr(callback, 'im_self'):
-                (completed, dn) = callback(
-                    ldap, completed, failed, dn, entry_attrs, *keys, **options
-                )
-            else:
-                (completed, dn) = callback(
-                    self, ldap, completed, failed, dn, entry_attrs, *keys,
-                    **options
-                )
+        for callback in self.get_callbacks('post'):
+            (completed, dn) = callback(
+                self, ldap, completed, failed, dn, entry_attrs, *keys,
+                **options)
 
         entry_attrs['dn'] = dn
         return dict(
diff --git a/make-lint b/make-lint
index 7ecd59d7e8c5a644f812d4b8987866e7d06236b5..f619260434e33886175f5b7d5a1d008466f92a54 100755
--- a/make-lint
+++ b/make-lint
@@ -51,8 +51,6 @@ class IPATypeChecker(TypeChecker):
         'ipalib.plugable.Plugin': ['Command', 'Object', 'Method', 'Property',
             'Backend', 'env', 'debug', 'info', 'warning', 'error', 'critical',
             'exception', 'context', 'log'],
-        'ipalib.plugins.baseldap.CallbackInterface': ['pre_callback',
-            'post_callback', 'exc_callback'],
         'ipalib.plugins.misc.env': ['env'],
         'ipalib.parameters.Param': ['cli_name', 'cli_short_name', 'label',
             'doc', 'required', 'multivalue', 'primary_key', 'normalizer',
diff --git a/tests/test_xmlrpc/test_baseldap_plugin.py b/tests/test_xmlrpc/test_baseldap_plugin.py
index 0800a5d5249b937176770965b6d41cf2b641d4c7..6a8501f76529c811e48fa0b36403da052af69b57 100644
--- a/tests/test_xmlrpc/test_baseldap_plugin.py
+++ b/tests/test_xmlrpc/test_baseldap_plugin.py
@@ -24,11 +24,12 @@
 from ipalib import errors
 from ipalib.plugins import baseldap
 
+
 def test_exc_wrapper():
     """Test the CallbackInterface._exc_wrapper helper method"""
     handled_exceptions = []
 
-    class test_callback(baseldap.CallbackInterface):
+    class test_callback(baseldap.BaseLDAPCommand):
         """Fake IPA method"""
         def test_fail(self):
             self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
@@ -64,3 +65,95 @@ def dont_handle(self, keys, options, e, call_func, *args, **kwargs):
 
     instance.test_fail()
     assert handled_exceptions == [None, errors.ExecutionError]
+
+
+def test_callback_registration():
+    class callbacktest_base(baseldap.CallbackInterface):
+        _callback_registry = dict(test={})
+
+        def test_callback(self, param):
+            messages.append(('Base test_callback', param))
+
+    def registered_callback(self, param):
+        messages.append(('Base registered callback', param))
+    callbacktest_base.register_callback('test', registered_callback)
+
+    class SomeClass(object):
+        def registered_callback(self, command, param):
+            messages.append(('Registered callback from another class', param))
+    callbacktest_base.register_callback('test', SomeClass().registered_callback)
+
+    class callbacktest_subclass(callbacktest_base):
+        pass
+
+    def subclass_callback(self, param):
+        messages.append(('Subclass registered callback', param))
+    callbacktest_subclass.register_callback('test', subclass_callback)
+
+
+    messages = []
+    instance = callbacktest_base()
+    for callback in instance.get_callbacks('test'):
+        callback(instance, 42)
+    assert messages == [
+            ('Base test_callback', 42),
+            ('Base registered callback', 42),
+            ('Registered callback from another class', 42)]
+
+    messages = []
+    instance = callbacktest_subclass()
+    for callback in instance.get_callbacks('test'):
+        callback(instance, 42)
+    assert messages == [
+            ('Base test_callback', 42),
+            ('Subclass registered callback', 42)]
+
+
+def test_exc_callback_registration():
+    messages = []
+    class callbacktest_base(baseldap.BaseLDAPCommand):
+        """A method superclass with an exception callback"""
+        def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+            """Let the world know we saw the error, but don't handle it"""
+            messages.append('Base exc_callback')
+            raise exc
+
+        def test_fail(self):
+            """Raise a handled exception"""
+            try:
+                self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
+            except Exception:
+                pass
+
+        def fail(self, *args, **kwargs):
+            """Raise an error"""
+            raise errors.ExecutionError('failure')
+
+    base_instance = callbacktest_base()
+
+    class callbacktest_subclass(callbacktest_base):
+        pass
+
+    @callbacktest_subclass.register_exc_callback
+    def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+        """Subclass's private exception callback"""
+        messages.append('Subclass registered callback')
+        raise exc
+
+    subclass_instance = callbacktest_subclass()
+
+    # Make sure exception in base class is only handled by the base class
+    base_instance.test_fail()
+    assert messages == ['Base exc_callback']
+
+
+    @callbacktest_base.register_exc_callback
+    def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+        """Callback on super class; doesn't affect the subclass"""
+        messages.append('Superclass registered callback')
+        raise exc
+
+    # Make sure exception in subclass is handled by both its own callbacks only
+    messages = []
+    subclass_instance.test_fail()
+    assert messages == ['Base exc_callback', 'Subclass registered callback']
--
1.7.10.1

_______________________________________________
Freeipa-devel mailing list
Freeipa-devel@redhat.com
https://www.redhat.com/mailman/listinfo/freeipa-devel

Reply via email to