
from overloaded import overloaded

# Identity sentinel for overload(...) argument-type lists: overload.apply
# replaces every entry that `is thisclass` with the class being defined, so a
# method can name its own (not yet created) class in its signature.
thisclass = object()

class overload(object):
    """Decorator that marks a method as an implementation of an overloaded target.

    Used inside a class whose metaclass is ``overload_meta``::

        @overload(some_target, [thisclass, int])
        def method(self, other): ...

    Stacking several ``@overload(...)`` decorators on one function registers it
    under several signatures: each outer decorator keeps a reference to the
    inner one in ``chain`` and shares its underlying ``func``.
    """

    def __init__(self, to_overload, args):
        """Remember the registration target and the argument-type signature.

        :param to_overload: object whose ``register_func(types, func)`` is
            called by :meth:`apply`.
        :param args: sequence (list or tuple) of argument types. Entries that
            are the ``thisclass`` sentinel are replaced by the owning class in
            :meth:`apply`.
        """
        self.to_overload = to_overload
        self.args = args
        # Filled in by __call__; initialised here so an instance that has not
        # decorated anything yet has a defined state instead of missing attrs.
        self.chain = None
        self.func = None

    def __call__(self, func):
        """Attach the decorated function (or an inner ``overload``) and return self."""
        if isinstance(func, overload):
            # Stacked decorator: keep the inner overload so every signature in
            # the stack gets applied, and reuse its underlying function.
            self.chain = func
            self.func = func.func
        else:
            self.chain = None
            self.func = func
        return self

    def apply(self, klass):
        """Register ``self.func`` on the target, substituting ``klass`` for ``thisclass``.

        Only this decorator's own signature is registered; walking ``chain``
        is left to the caller. ``self.args`` is never mutated.
        """
        # Build a fresh list instead of slicing + item assignment, so tuple
        # signatures work too (tuples do not support item assignment).
        newargs = [klass if obj is thisclass else obj for obj in self.args]
        self.to_overload.register_func(newargs, self.func)
     
class overload_meta(type):
    """Metaclass that registers ``overload``-decorated methods at class creation.

    For every attribute in the class namespace that is an ``overload``
    instance, each decorator in its ``chain`` is applied to the newly created
    class. The attribute is then deleted from the class, so the placeholder
    object is not left behind as a class attribute.
    """

    def __new__(cls, name, bases, dict):
        newtype = type.__new__(cls, name, bases, dict)
        replaced = []
        # items() instead of the Python-2-only iteritems() keeps this working
        # on Python 2 and 3; ``attr_name`` avoids rebinding the ``name`` param.
        for attr_name, obj in dict.items():
            if not isinstance(obj, overload):
                continue
            link = obj
            # Walk the stack of @overload decorators applied to one function.
            while link is not None:
                link.apply(newtype)
                link = link.chain
            replaced.append(attr_name)
        for attr_name in replaced:
            delattr(newtype, attr_name)
        return newtype
        
        
# Base class for classes that declare @overload methods: subclasses inherit
# overload_meta as their metaclass. It is built by calling the metaclass
# directly because the ``__metaclass__`` attribute is honoured only by
# Python 2 and ``class mixin(metaclass=...)`` is a SyntaxError there; this
# form applies overload_meta on both versions.
mixin = overload_meta("mixin", (object,), {
    "__module__": __name__,
    "__doc__": "Inherit from this to have @overload methods registered "
               "automatically by overload_meta.",
})
     

if __name__ == "__main__":
    # Demo: overload a hand-written ``len`` and a binary ``adder`` for Foo.
    # Single-argument print(...) prints the same on Python 2 and 3, so this
    # demo no longer makes the whole module a SyntaxError on Python 3.

    # Deliberately shadows the builtin ``len`` to show overloading it.
    @overloaded
    def len(sequence):
        """Fallback: count the items of any iterable."""
        count = 0
        for i in sequence:
            count += 1
        return count

    @overloaded
    def adder(lhs, rhs):
        """Fallback for unsupported operand types."""
        raise TypeError("Can't add these classes")

    class Foo(mixin):
        def __init__(self, val):
            self.val = val

        def __repr__(self):
            return "<repr '%s'>" % (self.val,)

        @overload(len, [thisclass])
        def len(self):
            return self.val

        @overload(adder, [thisclass, int])
        @overload(adder, [thisclass, float])
        def add(self, rhs):
            return Foo(self.val + rhs)

        # Registered for (number, Foo) signatures, so the Foo instance is the
        # second positional argument.
        @overload(adder, [int, thisclass])
        @overload(adder, [float, thisclass])
        def radd(lhs, self):
            return Foo(self.val + lhs)

    f = Foo(5)
    print(len(f))
    print(adder(f, 2.0))
    print(adder(2, f))