Author: mattip <[email protected]>
Branch: numpy-fixes
Changeset: r77151:f57b263d57d4
Date: 2015-05-06 00:44 +0300
http://bitbucket.org/pypy/pypy/changeset/f57b263d57d4/

Log:    test and fix an obscure use of __array_priority__ by classes such
        as numpy.polynomial.Polynomial

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -203,6 +203,9 @@
         assert isinstance(w_obj, BoolObject)
         return bool(w_obj.intval)
 
+    def gt(self, w_lhs, w_rhs):
+        return BoolObject(self.int_w(w_lhs) > self.int_w(w_rhs))
+
     def lt(self, w_lhs, w_rhs):
         return BoolObject(self.int_w(w_lhs) < self.int_w(w_rhs))
 
diff --git a/pypy/module/micronumpy/test/test_subtype.py b/pypy/module/micronumpy/test/test_subtype.py
--- a/pypy/module/micronumpy/test/test_subtype.py
+++ b/pypy/module/micronumpy/test/test_subtype.py
@@ -99,7 +99,27 @@
         ret = np.ndarray.__new__(np.ndarray, arr.shape, arr.dtype, buffer=arr)
         assert ret.__array_priority__ == 0.0
         assert (arr == ret).all()
+    
+    def test_priority(self):
+        from numpy import ndarray, arange, add
+        class DoReflected(object):
+            __array_priority__ = 10
+            def __radd__(self, other):
+                return 42
 
+        class A(object):
+            def __add__(self, other):
+                return NotImplemented
+
+
+        a = arange(10)
+        b = DoReflected()
+        c = A()
+        assert c + b == 42
+        assert a.__add__(b) is NotImplemented # not an exception
+        assert b.__radd__(a) == 42
+        assert a + b == 42
+        
     def test_finalize(self):
         #taken from 
http://docs.scipy.org/doc/numpy/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
         import numpy as np
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -322,6 +322,32 @@
         extobj_w = space.newlist([space.wrap(8192), space.wrap(0), 
space.w_None])
         return extobj_w
 
+def _has_reflected_op(space, w_obj, op):
+    refops ={ 'add': 'radd',
+            'subtract': 'rsub',
+            'multiply': 'rmul',
+            'divide': 'rdiv',
+            'true_divide': 'rtruediv',
+            'floor_divide': 'rfloordiv',
+            'remainder': 'rmod',
+            'power': 'rpow',
+            'left_shift': 'rlshift',
+            'right_shift': 'rrshift',
+            'bitwise_and': 'rand',
+            'bitwise_xor': 'rxor',
+            'bitwise_or': 'ror',
+            #/* Comparisons */
+            'equal': 'eq',
+            'not_equal': 'ne',
+            'greater': 'lt',
+            'less': 'gt',
+            'greater_equal': 'le',
+            'less_equal': 'ge',
+        }
+    if op not in refops:
+        return False
+    return space.getattr(w_obj, space.wrap('__' + refops[op] + '__')) is not 
None
+
 class W_Ufunc1(W_Ufunc):
     _immutable_fields_ = ["func", "bool_result"]
     nin = 1
@@ -432,6 +458,19 @@
         else:
             [w_lhs, w_rhs] = args_w
             w_out = None
+        if not isinstance(w_rhs, W_NDimArray):
+            # numpy implementation detail, useful for things like 
numpy.Polynomial
+            # FAIL with NotImplemented if the other object has
+            # the __r<op>__ method and has __array_priority__ as
+            # an attribute (signalling it can handle ndarray's)
+            # and is not already an ndarray or a subtype of the same type.
+            w_zero = space.wrap(0.0)
+            w_priority_l = space.findattr(w_lhs, 
space.wrap('__array_priority__')) or w_zero
+            w_priority_r = space.findattr(w_rhs, 
space.wrap('__array_priority__')) or w_zero
+            # XXX what is better, unwrapping values or space.gt?
+            r_greater = space.is_true(space.gt(w_priority_r, w_priority_l))
+            if r_greater and _has_reflected_op(space, w_rhs, self.name):
+                return space.w_NotImplemented
         w_lhs = numpify(space, w_lhs)
         w_rhs = numpify(space, w_rhs)
         w_ldtype = _get_dtype(space, w_lhs)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to