Author: Ronan Lamy <ronan.l...@gmail.com>
Branch:
Changeset: r77796:cad6015d5380
Date: 2015-06-03 00:04 +0100
http://bitbucket.org/pypy/pypy/changeset/cad6015d5380/
Log: Merged use_min_scalar into default diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -1349,3 +1349,6 @@ assert np.add(np.float16(0), np.longdouble(0)).dtype == np.longdouble assert np.add(np.float16(0), np.complex64(0)).dtype == np.complex64 assert np.add(np.float16(0), np.complex128(0)).dtype == np.complex128 + assert np.add(np.zeros(5, dtype=np.int8), 257).dtype == np.int16 + assert np.subtract(np.zeros(5, dtype=np.int8), 257).dtype == np.int16 + assert np.divide(np.zeros(5, dtype=np.int8), 257).dtype == np.int16 diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -21,7 +21,8 @@ from pypy.module.micronumpy.support import (_parse_signature, product, get_storage_as_int, is_rhs_priority_higher) from .casting import ( - can_cast_type, can_cast_to, find_result_type, promote_types) + can_cast_type, can_cast_array, can_cast_to, + find_result_type, promote_types) from .boxes import W_GenericBox, W_ObjectBox def done_if_true(dtype, val): @@ -495,17 +496,12 @@ return dt_in, dt_out, self.func def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'): - use_min_scalar = False if arg_dtype.is_object(): return arg_dtype, arg_dtype in_casting = safe_casting_mode(casting) for dt_in, dt_out in self.dtypes: - if use_min_scalar: - if not can_cast_array(space, w_arg, dt_in, in_casting): - continue - else: - if not can_cast_type(space, arg_dtype, dt_in, in_casting): - continue + if not can_cast_type(space, arg_dtype, dt_in, in_casting): + continue if out is not None: res_dtype = out.get_dtype() if not can_cast_type(space, dt_out, res_dtype, casting): @@ -605,21 +601,18 @@ w_rdtype.get_name(), w_ldtype.get_name(), self.name) - if self.are_common_types(w_ldtype, w_rdtype): - if not w_lhs.is_scalar() and 
w_rhs.is_scalar(): - w_rdtype = w_ldtype - elif w_lhs.is_scalar() and not w_rhs.is_scalar(): - w_ldtype = w_rdtype - calc_dtype, dt_out, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting) if (isinstance(w_lhs, W_GenericBox) and isinstance(w_rhs, W_GenericBox) and out is None): - return self.call_scalar(space, w_lhs, w_rhs, calc_dtype) + return self.call_scalar(space, w_lhs, w_rhs, casting) if isinstance(w_lhs, W_GenericBox): w_lhs = W_NDimArray.from_scalar(space, w_lhs) assert isinstance(w_lhs, W_NDimArray) if isinstance(w_rhs, W_GenericBox): w_rhs = W_NDimArray.from_scalar(space, w_rhs) assert isinstance(w_rhs, W_NDimArray) + calc_dtype, dt_out, func = self.find_specialization( + space, w_ldtype, w_rdtype, out, casting, w_lhs, w_rhs) + new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs) new_shape = shape_agreement(space, new_shape, out, broadcast_down=False) w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs) @@ -637,7 +630,10 @@ w_res = space.call_method(w_highpriority, '__array_wrap__', w_res, ctxt) return w_res - def call_scalar(self, space, w_lhs, w_rhs, in_dtype): + def call_scalar(self, space, w_lhs, w_rhs, casting): + in_dtype, out_dtype, func = self.find_specialization( + space, w_lhs.get_dtype(space), w_rhs.get_dtype(space), + out=None, casting=casting) w_val = self.func(in_dtype, w_lhs.convert_to(space, in_dtype), w_rhs.convert_to(space, in_dtype)) @@ -645,7 +641,8 @@ return w_val.w_obj return w_val - def _find_specialization(self, space, l_dtype, r_dtype, out, casting): + def _find_specialization(self, space, l_dtype, r_dtype, out, casting, + w_arg1, w_arg2): if (not self.allow_bool and (l_dtype.is_bool() or r_dtype.is_bool()) or not self.allow_complex and (l_dtype.is_complex() or @@ -657,15 +654,23 @@ dtype = find_result_type(space, [], [l_dtype, r_dtype]) bool_dtype = get_dtype_cache(space).w_booldtype return dtype, bool_dtype, self.func - dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, 
casting) + dt_in, dt_out = self._calc_dtype( + space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2) return dt_in, dt_out, self.func - def find_specialization(self, space, l_dtype, r_dtype, out, casting): + def find_specialization(self, space, l_dtype, r_dtype, out, casting, + w_arg1=None, w_arg2=None): if self.simple_binary: if out is None and not (l_dtype.is_object() or r_dtype.is_object()): - dtype = promote_types(space, l_dtype, r_dtype) + if w_arg1 is not None and w_arg2 is not None: + w_arg1 = convert_to_array(space, w_arg1) + w_arg2 = convert_to_array(space, w_arg2) + dtype = find_result_type(space, [w_arg1, w_arg2], []) + else: + dtype = promote_types(space, l_dtype, r_dtype) return dtype, dtype, self.func - return self._find_specialization(space, l_dtype, r_dtype, out, casting) + return self._find_specialization( + space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2) def find_binop_type(self, space, dtype): """Find a valid dtype signature of the form xx->x""" @@ -686,15 +691,21 @@ "requested type has type code '%s'" % (self.name, dtype.char)) - def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'): - use_min_scalar = False + def _calc_dtype(self, space, l_dtype, r_dtype, out, casting, + w_arg1, w_arg2): if l_dtype.is_object() or r_dtype.is_object(): dtype = get_dtype_cache(space).w_objectdtype return dtype, dtype + use_min_scalar = (w_arg1 is not None and w_arg2 is not None and + ((w_arg1.is_scalar() and not w_arg2.is_scalar()) or + (not w_arg1.is_scalar() and w_arg2.is_scalar()))) in_casting = safe_casting_mode(casting) for dt_in, dt_out in self.dtypes: if use_min_scalar: - if not can_cast_array(space, w_arg, dt_in, in_casting): + w_arg1 = convert_to_array(space, w_arg1) + w_arg2 = convert_to_array(space, w_arg2) + if not (can_cast_array(space, w_arg1, dt_in, in_casting) and + can_cast_array(space, w_arg2, dt_in, in_casting)): continue else: if not (can_cast_type(space, l_dtype, dt_in, in_casting) and 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit