Author: Brian Kearns <[email protected]>
Branch:
Changeset: r68972:a4ca1e94409a
Date: 2014-01-27 19:59 -0500
http://bitbucket.org/pypy/pypy/changeset/a4ca1e94409a/
Log: cleanup numpy array dot
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -903,8 +903,8 @@
w_res = self.descr_mul(space, other)
assert isinstance(w_res, W_NDimArray)
return w_res.descr_sum(space, space.wrap(-1), out)
- dtype = interp_ufuncs.find_binop_result_dtype(space,
- self.get_dtype(), other.get_dtype())
+ dtype = interp_ufuncs.find_binop_result_dtype(space, self.get_dtype(),
+ other.get_dtype())
if self.get_size() < 1 and other.get_size() < 1:
# numpy compatability
return W_NDimArray.new_scalar(space, dtype, space.wrap(0))
@@ -912,25 +912,27 @@
out_shape, other_critical_dim = _match_dot_shapes(space, self, other)
if out:
matches = True
- if len(out.get_shape()) != len(out_shape):
+ if dtype != out.get_dtype():
+ matches = False
+ elif not out.implementation.order == "C":
+ matches = False
+ elif len(out.get_shape()) != len(out_shape):
matches = False
else:
for i in range(len(out_shape)):
if out.get_shape()[i] != out_shape[i]:
matches = False
break
- if dtype != out.get_dtype():
- matches = False
- if not out.implementation.order == "C":
- matches = False
if not matches:
raise OperationError(space.w_ValueError, space.wrap(
- 'output array is not acceptable (must have the right type,
nr dimensions, and be a C-Array)'))
+ 'output array is not acceptable (must have the right type,
'
+ 'nr dimensions, and be a C-Array)'))
w_res = out
+ w_res.fill(space, self.get_dtype().coerce(space, None))
else:
w_res = W_NDimArray.from_shape(space, out_shape, dtype,
w_instance=self)
# This is the place to add fpypy and blas
- return loop.multidim_dot(space, self, other, w_res, dtype,
+ return loop.multidim_dot(space, self, other, w_res, dtype,
other_critical_dim)
def descr_mean(self, space, __args__):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -146,8 +146,7 @@
while not obj_iter.done():
reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
done_func=done_func,
- calc_dtype=calc_dtype,
- )
+ calc_dtype=calc_dtype)
rval = obj_iter.getitem().convert_to(space, calc_dtype)
if done_func is not None and done_func(calc_dtype, rval):
return rval
@@ -172,8 +171,7 @@
shapelen = len(obj.get_shape())
while not obj_iter.done():
reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
- dtype=calc_dtype,
- )
+ dtype=calc_dtype)
rval = obj_iter.getitem().convert_to(space, calc_dtype)
cur_value = func(calc_dtype, cur_value, rval)
out_iter.setitem(cur_value)
@@ -271,8 +269,7 @@
iter.next()
shapelen = len(arr.get_shape())
while not iter.done():
- arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
- )
+ arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
w_val = iter.getitem()
new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
if dtype.itemtype.ne(new_best, cur_best):
@@ -311,6 +308,7 @@
if i != right_critical_dim]
right_skip = range(len(left_shape) - 1)
result_skip = [len(result.get_shape()) - (len(right_shape) > 1)]
+ assert result.get_dtype() == dtype
outi = result.create_dot_iter(broadcast_shape, result_skip)
lefti = left.create_dot_iter(broadcast_shape, left_skip)
righti = right.create_dot_iter(broadcast_shape, right_skip)
@@ -318,10 +316,10 @@
dot_driver.jit_merge_point(dtype=dtype)
lval = lefti.getitem().convert_to(space, dtype)
rval = righti.getitem().convert_to(space, dtype)
- outval = outi.getitem().convert_to(space, dtype)
+ outval = outi.getitem()
v = dtype.itemtype.mul(lval, rval)
- value = dtype.itemtype.add(v, outval).convert_to(space, dtype)
- outi.setitem(value)
+ v = dtype.itemtype.add(v, outval)
+ outi.setitem(v)
outi.next()
righti.next()
lefti.next()
@@ -652,8 +650,8 @@
out_iter = out.create_iter(shape)
while not arr_iter.done():
round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
- w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(space, dtype),
- decimals)
+ w_v = arr_iter.getitem().convert_to(space, dtype)
+ w_v = dtype.itemtype.round(w_v, decimals)
out_iter.setitem(w_v)
arr_iter.next()
out_iter.next()
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -56,6 +56,10 @@
b = arange(12).reshape(4, 3)
c = a.dot(b)
assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+ c = a.dot(b.astype(float))
+ assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+ c = a.astype(float).dot(b)
+ assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
a = arange(24).reshape(2, 3, 4)
raises(ValueError, "a.dot(a)")
@@ -91,9 +95,11 @@
out = arange(9).reshape(3, 3)
c = dot(a, b, out=out)
assert (c == out).all()
- out = arange(9,dtype=float).reshape(3, 3)
+ assert (c == [[42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
+ out = arange(9, dtype=float).reshape(3, 3)
exc = raises(ValueError, dot, a, b, out)
- assert exc.value[0].find('not acceptable') > 0
+ assert exc.value[0] == ('output array is not acceptable (must have the
'
+ 'right type, nr dimensions, and be a C-Array)')
def test_choose_basic(self):
from numpypy import array
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit