Author: Alex Gaynor <[email protected]>
Branch: numpy-dtype-alt
Changeset: r46706:5377b6e0918b
Date: 2011-08-22 12:41 -0500
http://bitbucket.org/pypy/pypy/changeset/5377b6e0918b/
Log: fix for sum/prod with various dtypes. breaks test_zjit.
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -218,6 +218,10 @@
class IntegerArithmeticDtype(object):
_mixin_ = True
+ @binop
+ def add(self, v1, v2):
+ return v1 + v2
+
def str_format(self, item):
return str(widen(self.unbox(item)))
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -92,13 +92,19 @@
reduce_driver.jit_merge_point(signature=self.signature,
self=self, res_dtype=res_dtype,
size=size, i=i, result=result)
- result = getattr(res_dtype, op_name)(result, self.eval(i))
+ result = getattr(res_dtype, op_name)(
+ result,
+ self.eval(i).convert_to(res_dtype)
+ )
i += 1
return result
def impl(self, space):
- result =
space.fromcache(interp_dtype.W_Float64Dtype).box(init).convert_to(self.find_dtype())
- return loop(self, self.find_dtype(), result,
self.find_size()).wrap(space)
+ dtype = interp_ufuncs.find_unaryop_result_dtype(
+ space, self.find_dtype(), promote_to_largest=True
+ )
+ result = dtype.adapt_val(init)
+ return loop(self, dtype, result, self.find_size()).wrap(space)
return func_with_new_name(impl, "reduce_%s_impl" % op_name)
def _reduce_max_min_impl(op_name):
@@ -178,8 +184,8 @@
def descr_any(self, space):
return space.wrap(self._any())
- descr_sum = _reduce_sum_prod_impl("add", 0.0)
- descr_prod = _reduce_sum_prod_impl("mul", 1.0)
+ descr_sum = _reduce_sum_prod_impl("add", 0)
+ descr_prod = _reduce_sum_prod_impl("mul", 1)
descr_max = _reduce_max_min_impl("max")
descr_min = _reduce_max_min_impl("min")
descr_argmax = _reduce_argmax_argmin_impl("max")
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -73,11 +73,19 @@
assert False
-def find_unaryop_result_dtype(space, dt, promote_to_float=False):
+def find_unaryop_result_dtype(space, dt, promote_to_float=False,
+ promote_to_largest=False):
if promote_to_float:
for bytes, dtype in interp_dtype.dtypes_by_num_bytes:
if dtype.kind == interp_dtype.FLOATINGLTR and dtype.num_bytes >=
dt.num_bytes:
return space.fromcache(dtype)
+ if promote_to_largest:
+ if dt.kind == interp_dtype.BOOLLTR or dt.kind ==
interp_dtype.SIGNEDLTR:
+ return space.fromcache(interp_dtype.W_Int64Dtype)
+ elif dt.kind == interp_dtype.FLOATINGLTR:
+ return space.fromcache(interp_dtype.W_Float64Dtype)
+ else:
+ assert False
return dt
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -412,6 +412,9 @@
assert a.sum() == 10.0
assert a[:4].sum() == 6.0
+ a = array([True] * 5, bool)
+ assert a.sum() == 5
+
def test_prod(self):
from numpy import array
a = array(range(1,6))
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -14,7 +14,7 @@
class TestNumpyJIt(LLJitMixin):
def setup_class(cls):
cls.space = FakeSpace()
- cls.float64_dtype = W_Float64Dtype(cls.space)
+ cls.float64_dtype = cls.space.fromcache(W_Float64Dtype)
def test_add(self):
def f(i):
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit