Author: Ronan Lamy <ronan.l...@gmail.com> Branch: fix-result-types Changeset: r77236:f1bdad98b4a8 Date: 2015-05-08 20:56 +0100 http://bitbucket.org/pypy/pypy/changeset/f1bdad98b4a8/
Log: move some stuff diff --git a/pypy/module/micronumpy/arrayops.py b/pypy/module/micronumpy/arrayops.py --- a/pypy/module/micronumpy/arrayops.py +++ b/pypy/module/micronumpy/arrayops.py @@ -1,11 +1,12 @@ from pypy.interpreter.error import OperationError, oefmt from pypy.interpreter.gateway import unwrap_spec -from pypy.module.micronumpy import loop, descriptor, ufuncs, support, \ - constants as NPY +from pypy.module.micronumpy import loop, descriptor, support +from pypy.module.micronumpy import constants as NPY from pypy.module.micronumpy.base import convert_to_array, W_NDimArray from pypy.module.micronumpy.converters import clipmode_converter from pypy.module.micronumpy.strides import ( Chunk, Chunks, shape_agreement, shape_agreement_multiple) +from .casting import find_binop_result_dtype def where(space, w_arr, w_x=None, w_y=None): @@ -84,8 +85,7 @@ if arr.get_dtype().itemtype.bool(arr.get_scalar_value()): return x return y - dtype = ufuncs.find_binop_result_dtype(space, x.get_dtype(), - y.get_dtype()) + dtype = find_binop_result_dtype(space, x.get_dtype(), y.get_dtype()) shape = shape_agreement(space, arr.get_shape(), x) shape = shape_agreement(space, shape, y) out = W_NDimArray.from_shape(space, shape, dtype) @@ -148,7 +148,7 @@ elif dtype.is_record() or a_dt.is_record(): raise OperationError(space.w_TypeError, space.wrap("invalid type promotion")) - dtype = ufuncs.find_binop_result_dtype(space, dtype, + dtype = find_binop_result_dtype(space, dtype, arr.get_dtype()) # concatenate does not handle ndarray subtypes, it always returns a ndarray res = W_NDimArray.from_shape(space, shape, dtype, 'C') diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py --- a/pypy/module/micronumpy/casting.py +++ b/pypy/module/micronumpy/casting.py @@ -1,16 +1,15 @@ """Functions and helpers for converting between dtypes""" from rpython.rlib import jit +from rpython.rlib.rarithmetic import LONG_BIT from pypy.interpreter.gateway import unwrap_spec -from pypy.interpreter.error import oefmt +from pypy.interpreter.error import oefmt, OperationError from pypy.module.micronumpy.base import W_NDimArray, convert_to_array from pypy.module.micronumpy import constants as NPY -from pypy.module.micronumpy.ufuncs import ( - find_binop_result_dtype, find_dtype_for_scalar) from .types import ( Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType) -from .descriptor import get_dtype_cache, as_dtype, is_scalar_w +from .descriptor import get_dtype_cache, as_dtype, is_scalar_w, variable_dtype @jit.unroll_safe def result_type(space, __args__): @@ -106,3 +105,162 @@ return get_dtype_cache(space).dtypes_by_num[num] else: return dtype + +@jit.unroll_safe +def find_unaryop_result_dtype(space, dt, promote_to_float=False, + promote_bools=False, promote_to_largest=False): + if dt.is_object(): + return dt + if promote_to_largest: + if dt.kind == NPY.GENBOOLLTR or dt.kind == NPY.SIGNEDLTR: + if dt.elsize * 8 < LONG_BIT: + return get_dtype_cache(space).w_longdtype + elif dt.kind == NPY.UNSIGNEDLTR: + if dt.elsize * 8 < LONG_BIT: + return get_dtype_cache(space).w_ulongdtype + else: + assert dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR + return dt + if promote_bools and (dt.kind == NPY.GENBOOLLTR): + return get_dtype_cache(space).w_int8dtype + if promote_to_float: + if dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR: + return dt + if dt.num >= NPY.INT: + return get_dtype_cache(space).w_float64dtype + for bytes, dtype in get_dtype_cache(space).float_dtypes_by_num_bytes: + if (dtype.kind == NPY.FLOATINGLTR and + dtype.itemtype.get_element_size() > + dt.itemtype.get_element_size()): + return dtype + return dt + +def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False, + promote_bools=False): + if dt2 is None: + return dt1 + if dt1 is None: + return dt2 + + if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT: + return get_dtype_cache(space).w_objectdtype + + # dt1.num should be <= dt2.num + if dt1.num > dt2.num: + dt1, dt2 = dt2, dt1 + # Some operations promote op(bool, bool) to return int8, rather than bool + if promote_bools and (dt1.kind == dt2.kind == NPY.GENBOOLLTR): + return get_dtype_cache(space).w_int8dtype + + # Everything numeric promotes to complex + if dt2.is_complex() or dt1.is_complex(): + if dt2.num == NPY.HALF: + dt1, dt2 = dt2, dt1 + if dt2.num == NPY.CFLOAT: + if dt1.num == NPY.DOUBLE: + return get_dtype_cache(space).w_complex128dtype + elif dt1.num == NPY.LONGDOUBLE: + return get_dtype_cache(space).w_complexlongdtype + return get_dtype_cache(space).w_complex64dtype + elif dt2.num == NPY.CDOUBLE: + if dt1.num == NPY.LONGDOUBLE: + return get_dtype_cache(space).w_complexlongdtype + return get_dtype_cache(space).w_complex128dtype + elif dt2.num == NPY.CLONGDOUBLE: + return get_dtype_cache(space).w_complexlongdtype + else: + raise OperationError(space.w_TypeError, space.wrap("Unsupported types")) + + if promote_to_float: + return find_unaryop_result_dtype(space, dt2, promote_to_float=True) + # If they're the same kind, choose the greater one. + if dt1.kind == dt2.kind and not dt2.is_flexible(): + if dt2.num == NPY.HALF: + return dt1 + return dt2 + + # Everything promotes to float, and bool promotes to everything. + if dt2.kind == NPY.FLOATINGLTR or dt1.kind == NPY.GENBOOLLTR: + if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() == 2: + return get_dtype_cache(space).w_float32dtype + if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() >= 4: + return get_dtype_cache(space).w_float64dtype + if dt2.num == NPY.FLOAT and dt1.itemtype.get_element_size() >= 4: + return get_dtype_cache(space).w_float64dtype + return dt2 + + # for now this means mixing signed and unsigned + if dt2.kind == NPY.SIGNEDLTR: + # if dt2 has a greater number of bytes, then just go with it + if dt1.itemtype.get_element_size() < dt2.itemtype.get_element_size(): + return dt2 + # we need to promote both dtypes + dtypenum = dt2.num + 2 + elif dt2.num == NPY.ULONGLONG or (LONG_BIT == 64 and dt2.num == NPY.ULONG): + # UInt64 + signed = Float64 + dtypenum = NPY.DOUBLE + elif dt2.is_flexible(): + # For those operations that get here (concatenate, stack), + # flexible types take precedence over numeric type + if dt2.is_record(): + return dt2 + if dt1.is_str_or_unicode(): + if dt2.elsize >= dt1.elsize: + return dt2 + return dt1 + return dt2 + else: + # increase to the next signed type + dtypenum = dt2.num + 1 + newdtype = get_dtype_cache(space).dtypes_by_num[dtypenum] + + if (newdtype.itemtype.get_element_size() > dt2.itemtype.get_element_size() or + newdtype.kind == NPY.FLOATINGLTR): + return newdtype + else: + # we only promoted to long on 32-bit or to longlong on 64-bit + # this is really for dealing with the Long and Ulong dtypes + dtypenum += 2 + return get_dtype_cache(space).dtypes_by_num[dtypenum] + +def find_dtype_for_scalar(space, w_obj, current_guess=None): + from .boxes import W_GenericBox + bool_dtype = get_dtype_cache(space).w_booldtype + long_dtype = get_dtype_cache(space).w_longdtype + int64_dtype = get_dtype_cache(space).w_int64dtype + uint64_dtype = get_dtype_cache(space).w_uint64dtype + complex_dtype = get_dtype_cache(space).w_complex128dtype + float_dtype = get_dtype_cache(space).w_float64dtype + object_dtype = get_dtype_cache(space).w_objectdtype + if isinstance(w_obj, W_GenericBox): + dtype = w_obj.get_dtype(space) + return find_binop_result_dtype(space, dtype, current_guess) + + if space.isinstance_w(w_obj, space.w_bool): + return find_binop_result_dtype(space, bool_dtype, current_guess) + elif space.isinstance_w(w_obj, space.w_int): + return find_binop_result_dtype(space, long_dtype, current_guess) + elif space.isinstance_w(w_obj, space.w_long): + try: + space.int_w(w_obj) + except OperationError, e: + if e.match(space, space.w_OverflowError): + if space.is_true(space.le(w_obj, space.wrap(0))): + return find_binop_result_dtype(space, int64_dtype, + current_guess) + return find_binop_result_dtype(space, uint64_dtype, + current_guess) + raise + return find_binop_result_dtype(space, int64_dtype, current_guess) + elif space.isinstance_w(w_obj, space.w_float): + return find_binop_result_dtype(space, float_dtype, current_guess) + elif space.isinstance_w(w_obj, space.w_complex): + return complex_dtype + elif space.isinstance_w(w_obj, space.w_str): + if current_guess is None: + return variable_dtype(space, 'S%d' % space.len_w(w_obj)) + elif current_guess.num == NPY.STRING: + if current_guess.elsize < space.len_w(w_obj): + return variable_dtype(space, 'S%d' % space.len_w(w_obj)) + return current_guess + return object_dtype diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py --- a/pypy/module/micronumpy/descriptor.py +++ b/pypy/module/micronumpy/descriptor.py @@ -29,7 +29,7 @@ """ agree on dtype from a list of arrays. if out is allocated, use it's dtype, otherwise allocate a new one with agreed dtype """ - from pypy.module.micronumpy.ufuncs import find_binop_result_dtype + from .casting import find_binop_result_dtype if not space.is_none(out): return out @@ -1011,7 +1011,7 @@ return space.fromcache(DtypeCache) def as_dtype(space, w_arg, allow_None=True): - from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar + from pypy.module.micronumpy.casting import find_dtype_for_scalar # roughly equivalent to CNumPy's PyArray_DescrConverter2 if not allow_None and space.is_none(w_arg): raise TypeError("Cannot create dtype from None here") diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py --- a/pypy/module/micronumpy/ndarray.py +++ b/pypy/module/micronumpy/ndarray.py @@ -988,6 +988,7 @@ return space.newtuple([w_quotient, w_remainder]) def descr_dot(self, space, w_other, w_out=None): + from .casting import find_binop_result_dtype if space.is_none(w_out): out = None elif not isinstance(w_out, W_NDimArray): @@ -1002,7 +1003,7 @@ w_res = self.descr_mul(space, other) assert isinstance(w_res, W_NDimArray) return w_res.descr_sum(space, space.wrap(-1), out) - dtype = ufuncs.find_binop_result_dtype(space, self.get_dtype(), + dtype = find_binop_result_dtype(space, self.get_dtype(), other.get_dtype()) if self.get_size() < 1 and other.get_size() < 1: # numpy compatability diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py --- a/pypy/module/micronumpy/nditer.py +++ b/pypy/module/micronumpy/nditer.py @@ -9,6 +9,7 @@ from pypy.module.micronumpy.iterators import ArrayIter from pypy.module.micronumpy.strides import (calculate_broadcast_strides, shape_agreement, shape_agreement_multiple) +from pypy.module.micronumpy.casting import find_binop_result_dtype def parse_op_arg(space, name, w_op_flags, n, parse_one_arg): @@ -173,7 +174,7 @@ def __init__(self, array, size, shape, strides, backstrides, op_flags, base): OperandIter.__init__(self, array, size, shape, strides, backstrides) - self.slice_shape =[] + self.slice_shape =[] self.slice_stride = [] self.slice_backstride = [] if op_flags.rw == 'r': @@ -302,7 +303,7 @@ But after coalesce(), getoperand() will return a slice by removing the fastest varying dimension(s) from the beginning or end of the shape. If flat is true, then the slice will be 1d, otherwise stack up the shape of - the fastest varying dimension in the slice, so an iterator of a 'C' array + the fastest varying dimension in the slice, so an iterator of a 'C' array of shape (2,4,3) after two calls to coalesce will iterate 2 times over a slice of shape (4,3) by setting the offset to the beginning of the data at each iteration ''' @@ -367,8 +368,6 @@ _immutable_fields_ = ['ndim', ] def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes, w_casting, w_op_axes, w_itershape, buffersize=0, order='K'): - from pypy.module.micronumpy.ufuncs import find_binop_result_dtype - self.order = order self.external_loop = False self.buffered = False diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py --- a/pypy/module/micronumpy/strides.py +++ b/pypy/module/micronumpy/strides.py @@ -215,7 +215,7 @@ def find_dtype_for_seq(space, elems_w, dtype): - from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar + from pypy.module.micronumpy.casting import find_dtype_for_scalar if len(elems_w) == 1: w_elem = elems_w[0] if isinstance(w_elem, W_NDimArray) and w_elem.is_scalar(): @@ -225,7 +225,7 @@ def _find_dtype_for_seq(space, elems_w, dtype): - from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar + from pypy.module.micronumpy.casting import find_dtype_for_scalar for w_elem in elems_w: if isinstance(w_elem, W_NDimArray) and w_elem.is_scalar(): w_elem = w_elem.get_scalar_value() diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py --- a/pypy/module/micronumpy/test/test_casting.py +++ b/pypy/module/micronumpy/test/test_casting.py @@ -1,4 +1,7 @@ from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest +from pypy.module.micronumpy.descriptor import get_dtype_cache +from pypy.module.micronumpy.casting import ( + find_unaryop_result_dtype, find_binop_result_dtype) class AppTestNumSupport(BaseNumpyAppTest): @@ -119,3 +122,80 @@ assert np.min_scalar_type(2**64 - 1) == np.dtype('uint64') # XXX: np.asarray(2**64) fails with OverflowError # assert np.min_scalar_type(2**64) == np.dtype('O') + +class TestCoercion(object): + def test_binops(self, space): + bool_dtype = get_dtype_cache(space).w_booldtype + int8_dtype = get_dtype_cache(space).w_int8dtype + int32_dtype = get_dtype_cache(space).w_int32dtype + float64_dtype = get_dtype_cache(space).w_float64dtype + c64_dtype = get_dtype_cache(space).w_complex64dtype + c128_dtype = get_dtype_cache(space).w_complex128dtype + cld_dtype = get_dtype_cache(space).w_complexlongdtype + fld_dtype = get_dtype_cache(space).w_floatlongdtype + + # Basic pairing + assert find_binop_result_dtype(space, bool_dtype, bool_dtype) is bool_dtype + assert find_binop_result_dtype(space, bool_dtype, float64_dtype) is float64_dtype + assert find_binop_result_dtype(space, float64_dtype, bool_dtype) is float64_dtype + assert find_binop_result_dtype(space, int32_dtype, int8_dtype) is int32_dtype + assert find_binop_result_dtype(space, int32_dtype, bool_dtype) is int32_dtype + assert find_binop_result_dtype(space, c64_dtype, float64_dtype) is c128_dtype + assert find_binop_result_dtype(space, c64_dtype, fld_dtype) is cld_dtype + assert find_binop_result_dtype(space, c128_dtype, fld_dtype) is cld_dtype + + # With promote bool (happens on div), the result is that the op should + # promote bools to int8 + assert find_binop_result_dtype(space, bool_dtype, bool_dtype, promote_bools=True) is int8_dtype + assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_bools=True) is float64_dtype + + # Coerce to floats + assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_to_float=True) is float64_dtype + + def test_unaryops(self, space): + bool_dtype = get_dtype_cache(space).w_booldtype + int8_dtype = get_dtype_cache(space).w_int8dtype + uint8_dtype = get_dtype_cache(space).w_uint8dtype + int16_dtype = get_dtype_cache(space).w_int16dtype + uint16_dtype = get_dtype_cache(space).w_uint16dtype + int32_dtype = get_dtype_cache(space).w_int32dtype + uint32_dtype = get_dtype_cache(space).w_uint32dtype + long_dtype = get_dtype_cache(space).w_longdtype + ulong_dtype = get_dtype_cache(space).w_ulongdtype + int64_dtype = get_dtype_cache(space).w_int64dtype + uint64_dtype = get_dtype_cache(space).w_uint64dtype + float16_dtype = get_dtype_cache(space).w_float16dtype + float32_dtype = get_dtype_cache(space).w_float32dtype + float64_dtype = get_dtype_cache(space).w_float64dtype + + # Normal rules, everything returns itself + assert find_unaryop_result_dtype(space, bool_dtype) is bool_dtype + assert find_unaryop_result_dtype(space, int8_dtype) is int8_dtype + assert find_unaryop_result_dtype(space, uint8_dtype) is uint8_dtype + assert find_unaryop_result_dtype(space, int16_dtype) is int16_dtype + assert find_unaryop_result_dtype(space, uint16_dtype) is uint16_dtype + assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype + assert find_unaryop_result_dtype(space, uint32_dtype) is uint32_dtype + assert find_unaryop_result_dtype(space, long_dtype) is long_dtype + assert find_unaryop_result_dtype(space, ulong_dtype) is ulong_dtype + assert find_unaryop_result_dtype(space, int64_dtype) is int64_dtype + assert find_unaryop_result_dtype(space, uint64_dtype) is uint64_dtype + assert find_unaryop_result_dtype(space, float32_dtype) is float32_dtype + assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype + + # Coerce to floats, some of these will eventually be float16, or + # whatever our smallest float type is. + assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype + assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype + assert find_unaryop_result_dtype(space, uint8_dtype, promote_to_float=True) is float16_dtype + assert find_unaryop_result_dtype(space, int16_dtype, promote_to_float=True) is float32_dtype + assert find_unaryop_result_dtype(space, uint16_dtype, promote_to_float=True) is float32_dtype + assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype + assert find_unaryop_result_dtype(space, uint32_dtype, promote_to_float=True) is float64_dtype + assert find_unaryop_result_dtype(space, int64_dtype, promote_to_float=True) is float64_dtype + assert find_unaryop_result_dtype(space, uint64_dtype, promote_to_float=True) is float64_dtype + assert find_unaryop_result_dtype(space, float32_dtype, promote_to_float=True) is float32_dtype + assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype + + # promote bools, happens with sign ufunc + assert find_unaryop_result_dtype(space, bool_dtype, promote_bools=True) is int8_dtype diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -1,93 +1,12 @@ from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest -from pypy.module.micronumpy.ufuncs import (find_binop_result_dtype, - find_unaryop_result_dtype, W_UfuncGeneric) +from pypy.module.micronumpy.ufuncs import W_UfuncGeneric from pypy.module.micronumpy.support import _parse_signature from pypy.module.micronumpy.descriptor import get_dtype_cache from pypy.module.micronumpy.base import W_NDimArray from pypy.module.micronumpy.concrete import VoidBoxStorage -from pypy.interpreter.gateway import interp2app -from pypy.conftest import option from pypy.interpreter.error import OperationError -class TestUfuncCoercion(object): - def test_binops(self, space): - bool_dtype = get_dtype_cache(space).w_booldtype - int8_dtype = get_dtype_cache(space).w_int8dtype - int32_dtype = get_dtype_cache(space).w_int32dtype - float64_dtype = get_dtype_cache(space).w_float64dtype - c64_dtype = get_dtype_cache(space).w_complex64dtype - c128_dtype = get_dtype_cache(space).w_complex128dtype - cld_dtype = get_dtype_cache(space).w_complexlongdtype - fld_dtype = get_dtype_cache(space).w_floatlongdtype - - # Basic pairing - assert find_binop_result_dtype(space, bool_dtype, bool_dtype) is bool_dtype - assert find_binop_result_dtype(space, bool_dtype, float64_dtype) is float64_dtype - assert find_binop_result_dtype(space, float64_dtype, bool_dtype) is float64_dtype - assert find_binop_result_dtype(space, int32_dtype, int8_dtype) is int32_dtype - assert find_binop_result_dtype(space, int32_dtype, bool_dtype) is int32_dtype - assert find_binop_result_dtype(space, c64_dtype, float64_dtype) is c128_dtype - assert find_binop_result_dtype(space, c64_dtype, fld_dtype) is cld_dtype - assert find_binop_result_dtype(space, c128_dtype, fld_dtype) is cld_dtype - - # With promote bool (happens on div), the result is that the op should - # promote bools to int8 - assert find_binop_result_dtype(space, bool_dtype, bool_dtype, promote_bools=True) is int8_dtype - assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_bools=True) is float64_dtype - - # Coerce to floats - assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_to_float=True) is float64_dtype - - def test_unaryops(self, space): - bool_dtype = get_dtype_cache(space).w_booldtype - int8_dtype = get_dtype_cache(space).w_int8dtype - uint8_dtype = get_dtype_cache(space).w_uint8dtype - int16_dtype = get_dtype_cache(space).w_int16dtype - uint16_dtype = get_dtype_cache(space).w_uint16dtype - int32_dtype = get_dtype_cache(space).w_int32dtype - uint32_dtype = get_dtype_cache(space).w_uint32dtype - long_dtype = get_dtype_cache(space).w_longdtype - ulong_dtype = get_dtype_cache(space).w_ulongdtype - int64_dtype = get_dtype_cache(space).w_int64dtype - uint64_dtype = get_dtype_cache(space).w_uint64dtype - float16_dtype = get_dtype_cache(space).w_float16dtype - float32_dtype = get_dtype_cache(space).w_float32dtype - float64_dtype = get_dtype_cache(space).w_float64dtype - - # Normal rules, everything returns itself - assert find_unaryop_result_dtype(space, bool_dtype) is bool_dtype - assert find_unaryop_result_dtype(space, int8_dtype) is int8_dtype - assert find_unaryop_result_dtype(space, uint8_dtype) is uint8_dtype - assert find_unaryop_result_dtype(space, int16_dtype) is int16_dtype - assert find_unaryop_result_dtype(space, uint16_dtype) is uint16_dtype - assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype - assert find_unaryop_result_dtype(space, uint32_dtype) is uint32_dtype - assert find_unaryop_result_dtype(space, long_dtype) is long_dtype - assert find_unaryop_result_dtype(space, ulong_dtype) is ulong_dtype - assert find_unaryop_result_dtype(space, int64_dtype) is int64_dtype - assert find_unaryop_result_dtype(space, uint64_dtype) is uint64_dtype - assert find_unaryop_result_dtype(space, float32_dtype) is float32_dtype - assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype - - # Coerce to floats, some of these will eventually be float16, or - # whatever our smallest float type is. - assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype - assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype - assert find_unaryop_result_dtype(space, uint8_dtype, promote_to_float=True) is float16_dtype - assert find_unaryop_result_dtype(space, int16_dtype, promote_to_float=True) is float32_dtype - assert find_unaryop_result_dtype(space, uint16_dtype, promote_to_float=True) is float32_dtype - assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype - assert find_unaryop_result_dtype(space, uint32_dtype, promote_to_float=True) is float64_dtype - assert find_unaryop_result_dtype(space, int64_dtype, promote_to_float=True) is float64_dtype - assert find_unaryop_result_dtype(space, uint64_dtype, promote_to_float=True) is float64_dtype - assert find_unaryop_result_dtype(space, float32_dtype, promote_to_float=True) is float32_dtype - assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype - - # promote bools, happens with sign ufunc - assert find_unaryop_result_dtype(space, bool_dtype, promote_bools=True) is int8_dtype - - class TestGenericUfuncOperation(object): def test_signature_parser(self, space): class Ufunc(object): @@ -96,10 +15,10 @@ self.nout = nout self.nargs = nin + nout self.core_enabled = True - self.core_num_dim_ix = 0 - self.core_num_dims = [0] * self.nargs + self.core_num_dim_ix = 0 + self.core_num_dims = [0] * self.nargs self.core_offsets = [0] * self.nargs - self.core_dim_ixs = [] + self.core_dim_ixs = [] u = Ufunc(2, 1) _parse_signature(space, u, '(m,n), (n,r)->(m,r)') @@ -116,8 +35,8 @@ b_dtype = get_dtype_cache(space).w_booldtype ufunc = W_UfuncGeneric(space, [None, None, None], 'eigenvals', None, 1, 1, - [f32_dtype, c64_dtype, - f64_dtype, c128_dtype, + [f32_dtype, c64_dtype, + f64_dtype, c128_dtype, c128_dtype, c128_dtype], '') f32_array = W_NDimArray(VoidBoxStorage(0, f32_dtype)) @@ -167,7 +86,7 @@ assert 'object' in str(e) # Use pypy specific extension for out_dtype adder_ufunc0 = frompyfunc(adder, 2, 1, dtypes=['match']) - sumdiff = frompyfunc(sumdiff, 2, 2, dtypes=['match'], + sumdiff = frompyfunc(sumdiff, 2, 2, dtypes=['match'], signature='(i),(i)->(i),(i)') adder_ufunc1 = frompyfunc([adder, adder], 2, 1, dtypes=[int, int, int, float, float, float]) diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -4,22 +4,21 @@ from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty from pypy.interpreter.argument import Arguments from rpython.rlib import jit -from rpython.rlib.rarithmetic import LONG_BIT, maxint +from rpython.rlib.rarithmetic import LONG_BIT, maxint, _get_bitsize from rpython.tool.sourcetools import func_with_new_name +from rpython.rlib.rawstorage import ( + raw_storage_setitem, free_raw_storage, alloc_raw_storage) +from rpython.rtyper.lltypesystem import rffi, lltype +from rpython.rlib.objectmodel import keepalive_until_here + from pypy.module.micronumpy import boxes, loop, constants as NPY -from pypy.module.micronumpy.descriptor import (get_dtype_cache, - variable_dtype, decode_w_dtype) +from pypy.module.micronumpy.descriptor import get_dtype_cache, decode_w_dtype from pypy.module.micronumpy.base import convert_to_array, W_NDimArray from pypy.module.micronumpy.ctors import numpify from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter from pypy.module.micronumpy.strides import shape_agreement from pypy.module.micronumpy.support import _parse_signature, product, get_storage_as_int -from rpython.rlib.rawstorage import (raw_storage_setitem, free_raw_storage, - alloc_raw_storage) -from rpython.rtyper.lltypesystem import rffi, lltype -from rpython.rlib.rarithmetic import LONG_BIT, _get_bitsize -from rpython.rlib.objectmodel import keepalive_until_here - +from .casting import find_unaryop_result_dtype, find_binop_result_dtype def done_if_true(dtype, val): return dtype.itemtype.bool(val) @@ -445,7 +444,7 @@ self.comparison_func and w_out is None: if self.name in ('equal', 'less_equal', 'less'): return space.wrap(False) - return space.wrap(True) + return space.wrap(True) elif (w_rdtype.is_str()) and \ self.comparison_func and w_out is None: if self.name in ('not_equal','less', 'less_equal'): @@ -955,170 +954,6 @@ ) -def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False, - promote_bools=False): - if dt2 is None: - return dt1 - if dt1 is None: - return dt2 - - if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT: - return get_dtype_cache(space).w_objectdtype - - # dt1.num should be <= dt2.num - if dt1.num > dt2.num: - dt1, dt2 = dt2, dt1 - # Some operations promote op(bool, bool) to return int8, rather than bool - if promote_bools and (dt1.kind == dt2.kind == NPY.GENBOOLLTR): - return get_dtype_cache(space).w_int8dtype - - # Everything numeric promotes to complex - if dt2.is_complex() or dt1.is_complex(): - if dt2.num == NPY.HALF: - dt1, dt2 = dt2, dt1 - if dt2.num == NPY.CFLOAT: - if dt1.num == NPY.DOUBLE: - return get_dtype_cache(space).w_complex128dtype - elif dt1.num == NPY.LONGDOUBLE: - return get_dtype_cache(space).w_complexlongdtype - return get_dtype_cache(space).w_complex64dtype - elif dt2.num == NPY.CDOUBLE: - if dt1.num == NPY.LONGDOUBLE: - return get_dtype_cache(space).w_complexlongdtype - return get_dtype_cache(space).w_complex128dtype - elif dt2.num == NPY.CLONGDOUBLE: - return get_dtype_cache(space).w_complexlongdtype - else: - raise OperationError(space.w_TypeError, space.wrap("Unsupported types")) - - if promote_to_float: - return find_unaryop_result_dtype(space, dt2, promote_to_float=True) - # If they're the same kind, choose the greater one. - if dt1.kind == dt2.kind and not dt2.is_flexible(): - if dt2.num == NPY.HALF: - return dt1 - return dt2 - - # Everything promotes to float, and bool promotes to everything. - if dt2.kind == NPY.FLOATINGLTR or dt1.kind == NPY.GENBOOLLTR: - if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() == 2: - return get_dtype_cache(space).w_float32dtype - if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() >= 4: - return get_dtype_cache(space).w_float64dtype - if dt2.num == NPY.FLOAT and dt1.itemtype.get_element_size() >= 4: - return get_dtype_cache(space).w_float64dtype - return dt2 - - # for now this means mixing signed and unsigned - if dt2.kind == NPY.SIGNEDLTR: - # if dt2 has a greater number of bytes, then just go with it - if dt1.itemtype.get_element_size() < dt2.itemtype.get_element_size(): - return dt2 - # we need to promote both dtypes - dtypenum = dt2.num + 2 - elif dt2.num == NPY.ULONGLONG or (LONG_BIT == 64 and dt2.num == NPY.ULONG): - # UInt64 + signed = Float64 - dtypenum = NPY.DOUBLE - elif dt2.is_flexible(): - # For those operations that get here (concatenate, stack), - # flexible types take precedence over numeric type - if dt2.is_record(): - return dt2 - if dt1.is_str_or_unicode(): - if dt2.elsize >= dt1.elsize: - return dt2 - return dt1 - return dt2 - else: - # increase to the next signed type - dtypenum = dt2.num + 1 - newdtype = get_dtype_cache(space).dtypes_by_num[dtypenum] - - if (newdtype.itemtype.get_element_size() > dt2.itemtype.get_element_size() or - newdtype.kind == NPY.FLOATINGLTR): - return newdtype - else: - # we only promoted to long on 32-bit or to longlong on 64-bit - # this is really for dealing with the Long and Ulong dtypes - dtypenum += 2 - return get_dtype_cache(space).dtypes_by_num[dtypenum] - - -@jit.unroll_safe -def find_unaryop_result_dtype(space, dt, promote_to_float=False, - promote_bools=False, promote_to_largest=False): - if dt.is_object(): - return dt - if promote_to_largest: - if dt.kind == NPY.GENBOOLLTR or dt.kind == NPY.SIGNEDLTR: - if dt.elsize * 8 < LONG_BIT: - return get_dtype_cache(space).w_longdtype - elif dt.kind == NPY.UNSIGNEDLTR: - if dt.elsize * 8 < LONG_BIT: - return get_dtype_cache(space).w_ulongdtype - else: - assert dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR - return dt - if promote_bools and (dt.kind == NPY.GENBOOLLTR): - return get_dtype_cache(space).w_int8dtype - if promote_to_float: - if dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR: - return dt - if dt.num >= NPY.INT: - return get_dtype_cache(space).w_float64dtype - for bytes, dtype in get_dtype_cache(space).float_dtypes_by_num_bytes: - if (dtype.kind == NPY.FLOATINGLTR and - dtype.itemtype.get_element_size() > - dt.itemtype.get_element_size()): - return dtype - return dt - - -def find_dtype_for_scalar(space, w_obj, current_guess=None): - bool_dtype = get_dtype_cache(space).w_booldtype - long_dtype = get_dtype_cache(space).w_longdtype - int64_dtype = get_dtype_cache(space).w_int64dtype - uint64_dtype = get_dtype_cache(space).w_uint64dtype - complex_dtype = get_dtype_cache(space).w_complex128dtype - float_dtype = get_dtype_cache(space).w_float64dtype - object_dtype = get_dtype_cache(space).w_objectdtype - if isinstance(w_obj, boxes.W_GenericBox): - dtype = w_obj.get_dtype(space) - return find_binop_result_dtype(space, dtype, current_guess) - - if space.isinstance_w(w_obj, space.w_bool): - return find_binop_result_dtype(space, bool_dtype, current_guess) - elif space.isinstance_w(w_obj, space.w_int): - return find_binop_result_dtype(space, long_dtype, current_guess) - elif space.isinstance_w(w_obj, space.w_long): - try: - space.int_w(w_obj) - except OperationError, e: - if e.match(space, space.w_OverflowError): - if space.is_true(space.le(w_obj, space.wrap(0))): - return find_binop_result_dtype(space, int64_dtype, - current_guess) - return find_binop_result_dtype(space, uint64_dtype, - current_guess) - raise - return find_binop_result_dtype(space, int64_dtype, current_guess) - elif space.isinstance_w(w_obj, space.w_float): - return find_binop_result_dtype(space, float_dtype, current_guess) - elif space.isinstance_w(w_obj, space.w_complex): - return complex_dtype - elif space.isinstance_w(w_obj, space.w_str): - if current_guess is None: - return variable_dtype(space, - 'S%d' % space.len_w(w_obj)) - elif current_guess.num == NPY.STRING: - if current_guess.elsize < space.len_w(w_obj): - return variable_dtype(space, - 'S%d' % space.len_w(w_obj)) - return current_guess - return object_dtype - #raise oefmt(space.w_NotImplementedError, - # 'unable to create dtype from objects, "%T" instance not ' - # 'supported', w_obj) def ufunc_dtype_caller(space, ufunc_name, op_name, nin, comparison_func, _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit