Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r68972:a4ca1e94409a
Date: 2014-01-27 19:59 -0500
http://bitbucket.org/pypy/pypy/changeset/a4ca1e94409a/

Log:    cleanup numpy array dot

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -903,8 +903,8 @@
             w_res = self.descr_mul(space, other)
             assert isinstance(w_res, W_NDimArray)
             return w_res.descr_sum(space, space.wrap(-1), out)
-        dtype = interp_ufuncs.find_binop_result_dtype(space,
-                                     self.get_dtype(), other.get_dtype())
+        dtype = interp_ufuncs.find_binop_result_dtype(space, self.get_dtype(),
+                                                             other.get_dtype())
         if self.get_size() < 1 and other.get_size() < 1:
             # numpy compatability
             return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
@@ -912,25 +912,27 @@
         out_shape, other_critical_dim = _match_dot_shapes(space, self, other)
         if out:
             matches = True
-            if len(out.get_shape()) != len(out_shape):
+            if dtype != out.get_dtype():
+                matches = False
+            elif not out.implementation.order == "C":
+                matches = False
+            elif len(out.get_shape()) != len(out_shape):
                 matches = False
             else:
                 for i in range(len(out_shape)):
                     if out.get_shape()[i] != out_shape[i]:
                         matches = False
                         break
-            if dtype != out.get_dtype():
-                matches = False
-            if not out.implementation.order == "C":
-                matches = False
             if not matches:
                 raise OperationError(space.w_ValueError, space.wrap(
-                    'output array is not acceptable (must have the right type, 
nr dimensions, and be a C-Array)'))
+                    'output array is not acceptable (must have the right type, 
'
+                    'nr dimensions, and be a C-Array)'))
             w_res = out
+            w_res.fill(space, self.get_dtype().coerce(space, None))
         else:
             w_res = W_NDimArray.from_shape(space, out_shape, dtype, 
w_instance=self)
         # This is the place to add fpypy and blas
-        return loop.multidim_dot(space, self, other,  w_res, dtype,
+        return loop.multidim_dot(space, self, other, w_res, dtype,
                                  other_critical_dim)
 
     def descr_mean(self, space, __args__):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -146,8 +146,7 @@
     while not obj_iter.done():
         reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                       done_func=done_func,
-                                      calc_dtype=calc_dtype,
-                                      )
+                                      calc_dtype=calc_dtype)
         rval = obj_iter.getitem().convert_to(space, calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
@@ -172,8 +171,7 @@
     shapelen = len(obj.get_shape())
     while not obj_iter.done():
         reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                          dtype=calc_dtype,
-                                         )
+                                          dtype=calc_dtype)
         rval = obj_iter.getitem().convert_to(space, calc_dtype)
         cur_value = func(calc_dtype, cur_value, rval)
         out_iter.setitem(cur_value)
@@ -271,8 +269,7 @@
         iter.next()
         shapelen = len(arr.get_shape())
         while not iter.done():
-            arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
-                                      )
+            arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
             w_val = iter.getitem()
             new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
             if dtype.itemtype.ne(new_best, cur_best):
@@ -311,6 +308,7 @@
                                          if i != right_critical_dim]
     right_skip = range(len(left_shape) - 1)
     result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+    assert result.get_dtype() == dtype
     outi = result.create_dot_iter(broadcast_shape, result_skip)
     lefti = left.create_dot_iter(broadcast_shape, left_skip)
     righti = right.create_dot_iter(broadcast_shape, right_skip)
@@ -318,10 +316,10 @@
         dot_driver.jit_merge_point(dtype=dtype)
         lval = lefti.getitem().convert_to(space, dtype)
         rval = righti.getitem().convert_to(space, dtype)
-        outval = outi.getitem().convert_to(space, dtype)
+        outval = outi.getitem()
         v = dtype.itemtype.mul(lval, rval)
-        value = dtype.itemtype.add(v, outval).convert_to(space, dtype)
-        outi.setitem(value)
+        v = dtype.itemtype.add(v, outval)
+        outi.setitem(v)
         outi.next()
         righti.next()
         lefti.next()
@@ -652,8 +650,8 @@
     out_iter = out.create_iter(shape)
     while not arr_iter.done():
         round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(space, dtype),
-                     decimals)
+        w_v = arr_iter.getitem().convert_to(space, dtype)
+        w_v = dtype.itemtype.round(w_v, decimals)
         out_iter.setitem(w_v)
         arr_iter.next()
         out_iter.next()
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -56,6 +56,10 @@
         b = arange(12).reshape(4, 3)
         c = a.dot(b)
         assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+        c = a.dot(b.astype(float))
+        assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+        c = a.astype(float).dot(b)
+        assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
 
         a = arange(24).reshape(2, 3, 4)
         raises(ValueError, "a.dot(a)")
@@ -91,9 +95,11 @@
         out = arange(9).reshape(3, 3)
         c = dot(a, b, out=out)
         assert (c == out).all()
-        out = arange(9,dtype=float).reshape(3, 3)
+        assert (c == [[42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+        out = arange(9, dtype=float).reshape(3, 3)
         exc = raises(ValueError, dot, a, b, out)
-        assert exc.value[0].find('not acceptable') > 0
+        assert exc.value[0] == ('output array is not acceptable (must have the 
'
+                                'right type, nr dimensions, and be a C-Array)')
 
     def test_choose_basic(self):
         from numpypy import array
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to