Author: Ronan Lamy <ronan.l...@gmail.com> Branch: fix-result-types Changeset: r77296:2fc8c1b68f07 Date: 2015-05-12 06:16 +0100 http://bitbucket.org/pypy/pypy/changeset/2fc8c1b68f07/
Log: Use the same logic as cnumpy in W_Ufunc1.find_specialization() diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py --- a/pypy/module/micronumpy/descriptor.py +++ b/pypy/module/micronumpy/descriptor.py @@ -900,17 +900,20 @@ NPY.CDOUBLE: self.w_float64dtype, NPY.CLONGDOUBLE: self.w_floatlongdtype, } - self.builtin_dtypes = [ - self.w_booldtype, + integer_dtypes = [ self.w_int8dtype, self.w_uint8dtype, self.w_int16dtype, self.w_uint16dtype, + self.w_int32dtype, self.w_uint32dtype, self.w_longdtype, self.w_ulongdtype, - self.w_int32dtype, self.w_uint32dtype, - self.w_int64dtype, self.w_uint64dtype, - ] + float_dtypes + complex_dtypes + [ - self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype, - self.w_objectdtype, - ] + self.w_int64dtype, self.w_uint64dtype] + self.builtin_dtypes = ([self.w_booldtype] + integer_dtypes + + float_dtypes + complex_dtypes + [ + self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype, + self.w_objectdtype, + ]) + self.integer_dtypes = integer_dtypes + self.float_dtypes = float_dtypes + self.complex_dtypes = complex_dtypes self.float_dtypes_by_num_bytes = sorted( (dtype.elsize, dtype) for dtype in float_dtypes diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -1,5 +1,5 @@ from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest -from pypy.module.micronumpy.ufuncs import W_UfuncGeneric +from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, W_Ufunc1 from pypy.module.micronumpy.support import _parse_signature from pypy.module.micronumpy.descriptor import get_dtype_cache from pypy.module.micronumpy.base import W_NDimArray @@ -54,6 +54,20 @@ exc = raises(OperationError, ufunc.type_resolver, space, [f32_array], [None], 'i->i', ufunc.dtypes) + def test_allowed_types(self, space): + dt_bool = get_dtype_cache(space).w_booldtype + dt_float16 = get_dtype_cache(space).w_float16dtype + dt_int32 = get_dtype_cache(space).w_int32dtype + ufunc = W_Ufunc1(None, 'x', int_only=True) + assert ufunc._calc_dtype(space, dt_bool) == dt_bool + assert ufunc.allowed_types(space) # XXX: shouldn't contain too much stuff + + ufunc = W_Ufunc1(None, 'x', promote_to_float=True) + assert ufunc._calc_dtype(space, dt_bool) == dt_float16 + + ufunc = W_Ufunc1(None, 'x') + assert ufunc._calc_dtype(space, dt_int32) == dt_int32 + class AppTestUfuncs(BaseNumpyAppTest): def test_constants(self): import numpy as np diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -18,7 +18,8 @@ from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter from pypy.module.micronumpy.strides import shape_agreement from pypy.module.micronumpy.support import _parse_signature, product, get_storage_as_int -from .casting import find_unaryop_result_dtype, find_binop_result_dtype +from .casting import ( + find_unaryop_result_dtype, find_binop_result_dtype, can_cast_type) def done_if_true(dtype, val): return dtype.itemtype.bool(val) @@ -384,12 +385,36 @@ not self.allow_complex and dtype.is_complex()): raise oefmt(space.w_TypeError, "ufunc %s not supported for the input type", self.name) - calc_dtype = find_unaryop_result_dtype(space, - dtype, - promote_to_float=self.promote_to_float, - promote_bools=self.promote_bools) + calc_dtype = self._calc_dtype(space, dtype) return calc_dtype, self.func + def _calc_dtype(self, space, arg_dtype): + use_min_scalar=False + if arg_dtype.is_object(): + return arg_dtype + for dtype in self.allowed_types(space): + if use_min_scalar: + if can_cast_array(space, w_arg, dtype, casting='safe'): + return dtype + else: + if can_cast_type(space, arg_dtype, dtype, casting='safe'): + return dtype + else: + raise oefmt(space.w_TypeError, + "No loop matching the specified signature was found " + "for ufunc %s", self.name) + + def allowed_types(self, space): + dtypes = [] + cache = get_dtype_cache(space) + if not self.promote_bools and not self.promote_to_float: + dtypes.append(cache.w_booldtype) + if not self.promote_to_float: + dtypes.extend(cache.integer_dtypes) + dtypes.extend(cache.float_dtypes) + dtypes.extend(cache.complex_dtypes) + return dtypes + class W_Ufunc2(W_Ufunc): _immutable_fields_ = ["func", "comparison_func", "done_func"] _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit