Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: fix-result-types
Changeset: r77236:f1bdad98b4a8
Date: 2015-05-08 20:56 +0100
http://bitbucket.org/pypy/pypy/changeset/f1bdad98b4a8/

Log:    move some stuff

diff --git a/pypy/module/micronumpy/arrayops.py b/pypy/module/micronumpy/arrayops.py
--- a/pypy/module/micronumpy/arrayops.py
+++ b/pypy/module/micronumpy/arrayops.py
@@ -1,11 +1,12 @@
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import unwrap_spec
-from pypy.module.micronumpy import loop, descriptor, ufuncs, support, \
-    constants as NPY
+from pypy.module.micronumpy import loop, descriptor, support
+from pypy.module.micronumpy import constants as NPY
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 from pypy.module.micronumpy.converters import clipmode_converter
 from pypy.module.micronumpy.strides import (
     Chunk, Chunks, shape_agreement, shape_agreement_multiple)
+from .casting import find_binop_result_dtype
 
 
 def where(space, w_arr, w_x=None, w_y=None):
@@ -84,8 +85,7 @@
         if arr.get_dtype().itemtype.bool(arr.get_scalar_value()):
             return x
         return y
-    dtype = ufuncs.find_binop_result_dtype(space, x.get_dtype(),
-                                                  y.get_dtype())
+    dtype = find_binop_result_dtype(space, x.get_dtype(), y.get_dtype())
     shape = shape_agreement(space, arr.get_shape(), x)
     shape = shape_agreement(space, shape, y)
     out = W_NDimArray.from_shape(space, shape, dtype)
@@ -148,7 +148,7 @@
         elif dtype.is_record() or a_dt.is_record():
             raise OperationError(space.w_TypeError,
                         space.wrap("invalid type promotion"))
-        dtype = ufuncs.find_binop_result_dtype(space, dtype,
+        dtype = find_binop_result_dtype(space, dtype,
                                                       arr.get_dtype())
     # concatenate does not handle ndarray subtypes, it always returns a ndarray
     res = W_NDimArray.from_shape(space, shape, dtype, 'C')
diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -1,16 +1,15 @@
 """Functions and helpers for converting between dtypes"""
 
 from rpython.rlib import jit
+from rpython.rlib.rarithmetic import LONG_BIT
 from pypy.interpreter.gateway import unwrap_spec
-from pypy.interpreter.error import oefmt
+from pypy.interpreter.error import oefmt, OperationError
 
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy import constants as NPY
-from pypy.module.micronumpy.ufuncs import (
-    find_binop_result_dtype, find_dtype_for_scalar)
 from .types import (
     Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType)
-from .descriptor import get_dtype_cache, as_dtype, is_scalar_w
+from .descriptor import get_dtype_cache, as_dtype, is_scalar_w, variable_dtype
 
 @jit.unroll_safe
 def result_type(space, __args__):
@@ -106,3 +105,162 @@
         return get_dtype_cache(space).dtypes_by_num[num]
     else:
         return dtype
+
+@jit.unroll_safe
+def find_unaryop_result_dtype(space, dt, promote_to_float=False,
+        promote_bools=False, promote_to_largest=False):
+    if dt.is_object():
+        return dt
+    if promote_to_largest:
+        if dt.kind == NPY.GENBOOLLTR or dt.kind == NPY.SIGNEDLTR:
+            if dt.elsize * 8 < LONG_BIT:
+                return get_dtype_cache(space).w_longdtype
+        elif dt.kind == NPY.UNSIGNEDLTR:
+            if dt.elsize * 8 < LONG_BIT:
+                return get_dtype_cache(space).w_ulongdtype
+        else:
+            assert dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR
+        return dt
+    if promote_bools and (dt.kind == NPY.GENBOOLLTR):
+        return get_dtype_cache(space).w_int8dtype
+    if promote_to_float:
+        if dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR:
+            return dt
+        if dt.num >= NPY.INT:
+            return get_dtype_cache(space).w_float64dtype
+        for bytes, dtype in get_dtype_cache(space).float_dtypes_by_num_bytes:
+            if (dtype.kind == NPY.FLOATINGLTR and
+                    dtype.itemtype.get_element_size() >
+                    dt.itemtype.get_element_size()):
+                return dtype
+    return dt
+
+def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False,
+        promote_bools=False):
+    if dt2 is None:
+        return dt1
+    if dt1 is None:
+        return dt2
+
+    if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT:
+        return get_dtype_cache(space).w_objectdtype
+
+    # dt1.num should be <= dt2.num
+    if dt1.num > dt2.num:
+        dt1, dt2 = dt2, dt1
+    # Some operations promote op(bool, bool) to return int8, rather than bool
+    if promote_bools and (dt1.kind == dt2.kind == NPY.GENBOOLLTR):
+        return get_dtype_cache(space).w_int8dtype
+
+    # Everything numeric promotes to complex
+    if dt2.is_complex() or dt1.is_complex():
+        if dt2.num == NPY.HALF:
+            dt1, dt2 = dt2, dt1
+        if dt2.num == NPY.CFLOAT:
+            if dt1.num == NPY.DOUBLE:
+                return get_dtype_cache(space).w_complex128dtype
+            elif dt1.num == NPY.LONGDOUBLE:
+                return get_dtype_cache(space).w_complexlongdtype
+            return get_dtype_cache(space).w_complex64dtype
+        elif dt2.num == NPY.CDOUBLE:
+            if dt1.num == NPY.LONGDOUBLE:
+                return get_dtype_cache(space).w_complexlongdtype
+            return get_dtype_cache(space).w_complex128dtype
+        elif dt2.num == NPY.CLONGDOUBLE:
+            return get_dtype_cache(space).w_complexlongdtype
+        else:
+            raise OperationError(space.w_TypeError, space.wrap("Unsupported types"))
+
+    if promote_to_float:
+        return find_unaryop_result_dtype(space, dt2, promote_to_float=True)
+    # If they're the same kind, choose the greater one.
+    if dt1.kind == dt2.kind and not dt2.is_flexible():
+        if dt2.num == NPY.HALF:
+            return dt1
+        return dt2
+
+    # Everything promotes to float, and bool promotes to everything.
+    if dt2.kind == NPY.FLOATINGLTR or dt1.kind == NPY.GENBOOLLTR:
+        if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() == 2:
+            return get_dtype_cache(space).w_float32dtype
+        if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() >= 4:
+            return get_dtype_cache(space).w_float64dtype
+        if dt2.num == NPY.FLOAT and dt1.itemtype.get_element_size() >= 4:
+            return get_dtype_cache(space).w_float64dtype
+        return dt2
+
+    # for now this means mixing signed and unsigned
+    if dt2.kind == NPY.SIGNEDLTR:
+        # if dt2 has a greater number of bytes, then just go with it
+        if dt1.itemtype.get_element_size() < dt2.itemtype.get_element_size():
+            return dt2
+        # we need to promote both dtypes
+        dtypenum = dt2.num + 2
+    elif dt2.num == NPY.ULONGLONG or (LONG_BIT == 64 and dt2.num == NPY.ULONG):
+        # UInt64 + signed = Float64
+        dtypenum = NPY.DOUBLE
+    elif dt2.is_flexible():
+        # For those operations that get here (concatenate, stack),
+        # flexible types take precedence over numeric type
+        if dt2.is_record():
+            return dt2
+        if dt1.is_str_or_unicode():
+            if dt2.elsize >= dt1.elsize:
+                return dt2
+            return dt1
+        return dt2
+    else:
+        # increase to the next signed type
+        dtypenum = dt2.num + 1
+    newdtype = get_dtype_cache(space).dtypes_by_num[dtypenum]
+
+    if (newdtype.itemtype.get_element_size() > dt2.itemtype.get_element_size() or
+            newdtype.kind == NPY.FLOATINGLTR):
+        return newdtype
+    else:
+        # we only promoted to long on 32-bit or to longlong on 64-bit
+        # this is really for dealing with the Long and Ulong dtypes
+        dtypenum += 2
+        return get_dtype_cache(space).dtypes_by_num[dtypenum]
+
+def find_dtype_for_scalar(space, w_obj, current_guess=None):
+    from .boxes import W_GenericBox
+    bool_dtype = get_dtype_cache(space).w_booldtype
+    long_dtype = get_dtype_cache(space).w_longdtype
+    int64_dtype = get_dtype_cache(space).w_int64dtype
+    uint64_dtype = get_dtype_cache(space).w_uint64dtype
+    complex_dtype = get_dtype_cache(space).w_complex128dtype
+    float_dtype = get_dtype_cache(space).w_float64dtype
+    object_dtype = get_dtype_cache(space).w_objectdtype
+    if isinstance(w_obj, W_GenericBox):
+        dtype = w_obj.get_dtype(space)
+        return find_binop_result_dtype(space, dtype, current_guess)
+
+    if space.isinstance_w(w_obj, space.w_bool):
+        return find_binop_result_dtype(space, bool_dtype, current_guess)
+    elif space.isinstance_w(w_obj, space.w_int):
+        return find_binop_result_dtype(space, long_dtype, current_guess)
+    elif space.isinstance_w(w_obj, space.w_long):
+        try:
+            space.int_w(w_obj)
+        except OperationError, e:
+            if e.match(space, space.w_OverflowError):
+                if space.is_true(space.le(w_obj, space.wrap(0))):
+                    return find_binop_result_dtype(space, int64_dtype,
+                                               current_guess)
+                return find_binop_result_dtype(space, uint64_dtype,
+                                               current_guess)
+            raise
+        return find_binop_result_dtype(space, int64_dtype, current_guess)
+    elif space.isinstance_w(w_obj, space.w_float):
+        return find_binop_result_dtype(space, float_dtype, current_guess)
+    elif space.isinstance_w(w_obj, space.w_complex):
+        return complex_dtype
+    elif space.isinstance_w(w_obj, space.w_str):
+        if current_guess is None:
+            return variable_dtype(space, 'S%d' % space.len_w(w_obj))
+        elif current_guess.num == NPY.STRING:
+            if current_guess.elsize < space.len_w(w_obj):
+                return variable_dtype(space, 'S%d' % space.len_w(w_obj))
+        return current_guess
+    return object_dtype
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -29,7 +29,7 @@
     """ agree on dtype from a list of arrays. if out is allocated,
     use it's dtype, otherwise allocate a new one with agreed dtype
     """
-    from pypy.module.micronumpy.ufuncs import find_binop_result_dtype
+    from .casting import find_binop_result_dtype
 
     if not space.is_none(out):
         return out
@@ -1011,7 +1011,7 @@
     return space.fromcache(DtypeCache)
 
 def as_dtype(space, w_arg, allow_None=True):
-    from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar
+    from pypy.module.micronumpy.casting import find_dtype_for_scalar
     # roughly equivalent to CNumPy's PyArray_DescrConverter2
     if not allow_None and space.is_none(w_arg):
         raise TypeError("Cannot create dtype from None here")
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -988,6 +988,7 @@
         return space.newtuple([w_quotient, w_remainder])
 
     def descr_dot(self, space, w_other, w_out=None):
+        from .casting import find_binop_result_dtype
         if space.is_none(w_out):
             out = None
         elif not isinstance(w_out, W_NDimArray):
@@ -1002,7 +1003,7 @@
             w_res = self.descr_mul(space, other)
             assert isinstance(w_res, W_NDimArray)
             return w_res.descr_sum(space, space.wrap(-1), out)
-        dtype = ufuncs.find_binop_result_dtype(space, self.get_dtype(),
+        dtype = find_binop_result_dtype(space, self.get_dtype(),
                                                other.get_dtype())
         if self.get_size() < 1 and other.get_size() < 1:
             # numpy compatability
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -9,6 +9,7 @@
 from pypy.module.micronumpy.iterators import ArrayIter
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, shape_agreement_multiple)
+from pypy.module.micronumpy.casting import find_binop_result_dtype
 
 
 def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
@@ -173,7 +174,7 @@
     def __init__(self, array, size, shape, strides, backstrides,
                  op_flags, base):
         OperandIter.__init__(self, array, size, shape, strides, backstrides)
-        self.slice_shape =[] 
+        self.slice_shape =[]
         self.slice_stride = []
         self.slice_backstride = []
         if op_flags.rw == 'r':
@@ -302,7 +303,7 @@
     But after coalesce(), getoperand() will return a slice by removing
     the fastest varying dimension(s) from the beginning or end of the shape.
     If flat is true, then the slice will be 1d, otherwise stack up the shape of
-    the fastest varying dimension in the slice, so an iterator of a  'C' array 
+    the fastest varying dimension in the slice, so an iterator of a  'C' array
     of shape (2,4,3) after two calls to coalesce will iterate 2 times over a slice
     of shape (4,3) by setting the offset to the beginning of the data at each iteration
     '''
@@ -367,8 +368,6 @@
     _immutable_fields_ = ['ndim', ]
     def __init__(self, space, w_seq, w_flags, w_op_flags, w_op_dtypes,
                  w_casting, w_op_axes, w_itershape, buffersize=0, order='K'):
-        from pypy.module.micronumpy.ufuncs import find_binop_result_dtype
-        
         self.order = order
         self.external_loop = False
         self.buffered = False
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -215,7 +215,7 @@
 
 
 def find_dtype_for_seq(space, elems_w, dtype):
-    from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar
+    from pypy.module.micronumpy.casting import find_dtype_for_scalar
     if len(elems_w) == 1:
         w_elem = elems_w[0]
         if isinstance(w_elem, W_NDimArray) and w_elem.is_scalar():
@@ -225,7 +225,7 @@
 
 
 def _find_dtype_for_seq(space, elems_w, dtype):
-    from pypy.module.micronumpy.ufuncs import find_dtype_for_scalar
+    from pypy.module.micronumpy.casting import find_dtype_for_scalar
     for w_elem in elems_w:
         if isinstance(w_elem, W_NDimArray) and w_elem.is_scalar():
             w_elem = w_elem.get_scalar_value()
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -1,4 +1,7 @@
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
+from pypy.module.micronumpy.descriptor import get_dtype_cache
+from pypy.module.micronumpy.casting import (
+    find_unaryop_result_dtype, find_binop_result_dtype)
 
 
 class AppTestNumSupport(BaseNumpyAppTest):
@@ -119,3 +122,80 @@
         assert np.min_scalar_type(2**64 - 1) == np.dtype('uint64')
         # XXX: np.asarray(2**64) fails with OverflowError
         # assert np.min_scalar_type(2**64) == np.dtype('O')
+
+class TestCoercion(object):
+    def test_binops(self, space):
+        bool_dtype = get_dtype_cache(space).w_booldtype
+        int8_dtype = get_dtype_cache(space).w_int8dtype
+        int32_dtype = get_dtype_cache(space).w_int32dtype
+        float64_dtype = get_dtype_cache(space).w_float64dtype
+        c64_dtype = get_dtype_cache(space).w_complex64dtype
+        c128_dtype = get_dtype_cache(space).w_complex128dtype
+        cld_dtype = get_dtype_cache(space).w_complexlongdtype
+        fld_dtype = get_dtype_cache(space).w_floatlongdtype
+
+        # Basic pairing
+        assert find_binop_result_dtype(space, bool_dtype, bool_dtype) is bool_dtype
+        assert find_binop_result_dtype(space, bool_dtype, float64_dtype) is float64_dtype
+        assert find_binop_result_dtype(space, float64_dtype, bool_dtype) is float64_dtype
+        assert find_binop_result_dtype(space, int32_dtype, int8_dtype) is int32_dtype
+        assert find_binop_result_dtype(space, int32_dtype, bool_dtype) is int32_dtype
+        assert find_binop_result_dtype(space, c64_dtype, float64_dtype) is c128_dtype
+        assert find_binop_result_dtype(space, c64_dtype, fld_dtype) is cld_dtype
+        assert find_binop_result_dtype(space, c128_dtype, fld_dtype) is cld_dtype
+
+        # With promote bool (happens on div), the result is that the op should
+        # promote bools to int8
+        assert find_binop_result_dtype(space, bool_dtype, bool_dtype, promote_bools=True) is int8_dtype
+        assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_bools=True) is float64_dtype
+
+        # Coerce to floats
+        assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_to_float=True) is float64_dtype
+
+    def test_unaryops(self, space):
+        bool_dtype = get_dtype_cache(space).w_booldtype
+        int8_dtype = get_dtype_cache(space).w_int8dtype
+        uint8_dtype = get_dtype_cache(space).w_uint8dtype
+        int16_dtype = get_dtype_cache(space).w_int16dtype
+        uint16_dtype = get_dtype_cache(space).w_uint16dtype
+        int32_dtype = get_dtype_cache(space).w_int32dtype
+        uint32_dtype = get_dtype_cache(space).w_uint32dtype
+        long_dtype = get_dtype_cache(space).w_longdtype
+        ulong_dtype = get_dtype_cache(space).w_ulongdtype
+        int64_dtype = get_dtype_cache(space).w_int64dtype
+        uint64_dtype = get_dtype_cache(space).w_uint64dtype
+        float16_dtype = get_dtype_cache(space).w_float16dtype
+        float32_dtype = get_dtype_cache(space).w_float32dtype
+        float64_dtype = get_dtype_cache(space).w_float64dtype
+
+        # Normal rules, everything returns itself
+        assert find_unaryop_result_dtype(space, bool_dtype) is bool_dtype
+        assert find_unaryop_result_dtype(space, int8_dtype) is int8_dtype
+        assert find_unaryop_result_dtype(space, uint8_dtype) is uint8_dtype
+        assert find_unaryop_result_dtype(space, int16_dtype) is int16_dtype
+        assert find_unaryop_result_dtype(space, uint16_dtype) is uint16_dtype
+        assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype
+        assert find_unaryop_result_dtype(space, uint32_dtype) is uint32_dtype
+        assert find_unaryop_result_dtype(space, long_dtype) is long_dtype
+        assert find_unaryop_result_dtype(space, ulong_dtype) is ulong_dtype
+        assert find_unaryop_result_dtype(space, int64_dtype) is int64_dtype
+        assert find_unaryop_result_dtype(space, uint64_dtype) is uint64_dtype
+        assert find_unaryop_result_dtype(space, float32_dtype) is float32_dtype
+        assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype
+
+        # Coerce to floats, some of these will eventually be float16, or
+        # whatever our smallest float type is.
+        assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype
+        assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype
+        assert find_unaryop_result_dtype(space, uint8_dtype, promote_to_float=True) is float16_dtype
+        assert find_unaryop_result_dtype(space, int16_dtype, promote_to_float=True) is float32_dtype
+        assert find_unaryop_result_dtype(space, uint16_dtype, promote_to_float=True) is float32_dtype
+        assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype
+        assert find_unaryop_result_dtype(space, uint32_dtype, promote_to_float=True) is float64_dtype
+        assert find_unaryop_result_dtype(space, int64_dtype, promote_to_float=True) is float64_dtype
+        assert find_unaryop_result_dtype(space, uint64_dtype, promote_to_float=True) is float64_dtype
+        assert find_unaryop_result_dtype(space, float32_dtype, promote_to_float=True) is float32_dtype
+        assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype
+
+        # promote bools, happens with sign ufunc
+        assert find_unaryop_result_dtype(space, bool_dtype, promote_bools=True) is int8_dtype
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1,93 +1,12 @@
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.ufuncs import (find_binop_result_dtype,
-        find_unaryop_result_dtype, W_UfuncGeneric)
+from pypy.module.micronumpy.ufuncs import W_UfuncGeneric
 from pypy.module.micronumpy.support import _parse_signature
 from pypy.module.micronumpy.descriptor import get_dtype_cache
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.concrete import VoidBoxStorage
-from pypy.interpreter.gateway import interp2app
-from pypy.conftest import option
 from pypy.interpreter.error import OperationError
 
 
-class TestUfuncCoercion(object):
-    def test_binops(self, space):
-        bool_dtype = get_dtype_cache(space).w_booldtype
-        int8_dtype = get_dtype_cache(space).w_int8dtype
-        int32_dtype = get_dtype_cache(space).w_int32dtype
-        float64_dtype = get_dtype_cache(space).w_float64dtype
-        c64_dtype = get_dtype_cache(space).w_complex64dtype
-        c128_dtype = get_dtype_cache(space).w_complex128dtype
-        cld_dtype = get_dtype_cache(space).w_complexlongdtype
-        fld_dtype = get_dtype_cache(space).w_floatlongdtype
-
-        # Basic pairing
-        assert find_binop_result_dtype(space, bool_dtype, bool_dtype) is bool_dtype
-        assert find_binop_result_dtype(space, bool_dtype, float64_dtype) is float64_dtype
-        assert find_binop_result_dtype(space, float64_dtype, bool_dtype) is float64_dtype
-        assert find_binop_result_dtype(space, int32_dtype, int8_dtype) is int32_dtype
-        assert find_binop_result_dtype(space, int32_dtype, bool_dtype) is int32_dtype
-        assert find_binop_result_dtype(space, c64_dtype, float64_dtype) is c128_dtype
-        assert find_binop_result_dtype(space, c64_dtype, fld_dtype) is cld_dtype
-        assert find_binop_result_dtype(space, c128_dtype, fld_dtype) is cld_dtype
-
-        # With promote bool (happens on div), the result is that the op should
-        # promote bools to int8
-        assert find_binop_result_dtype(space, bool_dtype, bool_dtype, promote_bools=True) is int8_dtype
-        assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_bools=True) is float64_dtype
-
-        # Coerce to floats
-        assert find_binop_result_dtype(space, bool_dtype, float64_dtype, promote_to_float=True) is float64_dtype
-
-    def test_unaryops(self, space):
-        bool_dtype = get_dtype_cache(space).w_booldtype
-        int8_dtype = get_dtype_cache(space).w_int8dtype
-        uint8_dtype = get_dtype_cache(space).w_uint8dtype
-        int16_dtype = get_dtype_cache(space).w_int16dtype
-        uint16_dtype = get_dtype_cache(space).w_uint16dtype
-        int32_dtype = get_dtype_cache(space).w_int32dtype
-        uint32_dtype = get_dtype_cache(space).w_uint32dtype
-        long_dtype = get_dtype_cache(space).w_longdtype
-        ulong_dtype = get_dtype_cache(space).w_ulongdtype
-        int64_dtype = get_dtype_cache(space).w_int64dtype
-        uint64_dtype = get_dtype_cache(space).w_uint64dtype
-        float16_dtype = get_dtype_cache(space).w_float16dtype
-        float32_dtype = get_dtype_cache(space).w_float32dtype
-        float64_dtype = get_dtype_cache(space).w_float64dtype
-
-        # Normal rules, everything returns itself
-        assert find_unaryop_result_dtype(space, bool_dtype) is bool_dtype
-        assert find_unaryop_result_dtype(space, int8_dtype) is int8_dtype
-        assert find_unaryop_result_dtype(space, uint8_dtype) is uint8_dtype
-        assert find_unaryop_result_dtype(space, int16_dtype) is int16_dtype
-        assert find_unaryop_result_dtype(space, uint16_dtype) is uint16_dtype
-        assert find_unaryop_result_dtype(space, int32_dtype) is int32_dtype
-        assert find_unaryop_result_dtype(space, uint32_dtype) is uint32_dtype
-        assert find_unaryop_result_dtype(space, long_dtype) is long_dtype
-        assert find_unaryop_result_dtype(space, ulong_dtype) is ulong_dtype
-        assert find_unaryop_result_dtype(space, int64_dtype) is int64_dtype
-        assert find_unaryop_result_dtype(space, uint64_dtype) is uint64_dtype
-        assert find_unaryop_result_dtype(space, float32_dtype) is float32_dtype
-        assert find_unaryop_result_dtype(space, float64_dtype) is float64_dtype
-
-        # Coerce to floats, some of these will eventually be float16, or
-        # whatever our smallest float type is.
-        assert find_unaryop_result_dtype(space, bool_dtype, promote_to_float=True) is float16_dtype
-        assert find_unaryop_result_dtype(space, int8_dtype, promote_to_float=True) is float16_dtype
-        assert find_unaryop_result_dtype(space, uint8_dtype, promote_to_float=True) is float16_dtype
-        assert find_unaryop_result_dtype(space, int16_dtype, promote_to_float=True) is float32_dtype
-        assert find_unaryop_result_dtype(space, uint16_dtype, promote_to_float=True) is float32_dtype
-        assert find_unaryop_result_dtype(space, int32_dtype, promote_to_float=True) is float64_dtype
-        assert find_unaryop_result_dtype(space, uint32_dtype, promote_to_float=True) is float64_dtype
-        assert find_unaryop_result_dtype(space, int64_dtype, promote_to_float=True) is float64_dtype
-        assert find_unaryop_result_dtype(space, uint64_dtype, promote_to_float=True) is float64_dtype
-        assert find_unaryop_result_dtype(space, float32_dtype, promote_to_float=True) is float32_dtype
-        assert find_unaryop_result_dtype(space, float64_dtype, promote_to_float=True) is float64_dtype
-
-        # promote bools, happens with sign ufunc
-        assert find_unaryop_result_dtype(space, bool_dtype, promote_bools=True) is int8_dtype
-
-
 class TestGenericUfuncOperation(object):
     def test_signature_parser(self, space):
         class Ufunc(object):
@@ -96,10 +15,10 @@
                 self.nout = nout
                 self.nargs = nin + nout
                 self.core_enabled = True
-                self.core_num_dim_ix = 0 
-                self.core_num_dims = [0] * self.nargs  
+                self.core_num_dim_ix = 0
+                self.core_num_dims = [0] * self.nargs
                 self.core_offsets = [0] * self.nargs
-                self.core_dim_ixs = [] 
+                self.core_dim_ixs = []
 
         u = Ufunc(2, 1)
         _parse_signature(space, u, '(m,n), (n,r)->(m,r)')
@@ -116,8 +35,8 @@
         b_dtype = get_dtype_cache(space).w_booldtype
 
         ufunc = W_UfuncGeneric(space, [None, None, None], 'eigenvals', None, 1, 1,
-                     [f32_dtype, c64_dtype, 
-                      f64_dtype, c128_dtype, 
+                     [f32_dtype, c64_dtype,
+                      f64_dtype, c128_dtype,
                       c128_dtype, c128_dtype],
                      '')
         f32_array = W_NDimArray(VoidBoxStorage(0, f32_dtype))
@@ -167,7 +86,7 @@
             assert 'object' in str(e)
             # Use pypy specific extension for out_dtype
             adder_ufunc0 = frompyfunc(adder, 2, 1, dtypes=['match'])
-            sumdiff = frompyfunc(sumdiff, 2, 2, dtypes=['match'], 
+            sumdiff = frompyfunc(sumdiff, 2, 2, dtypes=['match'],
                                     signature='(i),(i)->(i),(i)')
             adder_ufunc1 = frompyfunc([adder, adder], 2, 1,
                             dtypes=[int, int, int, float, float, float])
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -4,22 +4,21 @@
 from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty
 from pypy.interpreter.argument import Arguments
 from rpython.rlib import jit
-from rpython.rlib.rarithmetic import LONG_BIT, maxint
+from rpython.rlib.rarithmetic import LONG_BIT, maxint, _get_bitsize
 from rpython.tool.sourcetools import func_with_new_name
+from rpython.rlib.rawstorage import (
+    raw_storage_setitem, free_raw_storage, alloc_raw_storage)
+from rpython.rtyper.lltypesystem import rffi, lltype
+from rpython.rlib.objectmodel import keepalive_until_here
+
 from pypy.module.micronumpy import boxes, loop, constants as NPY
-from pypy.module.micronumpy.descriptor import (get_dtype_cache,
-            variable_dtype, decode_w_dtype)
+from pypy.module.micronumpy.descriptor import get_dtype_cache, decode_w_dtype
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 from pypy.module.micronumpy.ctors import numpify
 from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter
 from pypy.module.micronumpy.strides import shape_agreement
 from pypy.module.micronumpy.support import _parse_signature, product, get_storage_as_int
-from rpython.rlib.rawstorage import (raw_storage_setitem, free_raw_storage,
-             alloc_raw_storage)
-from rpython.rtyper.lltypesystem import rffi, lltype
-from rpython.rlib.rarithmetic import LONG_BIT, _get_bitsize
-from rpython.rlib.objectmodel import keepalive_until_here
-
+from .casting import find_unaryop_result_dtype, find_binop_result_dtype
 
 def done_if_true(dtype, val):
     return dtype.itemtype.bool(val)
@@ -445,7 +444,7 @@
                 self.comparison_func and w_out is None:
             if self.name in ('equal', 'less_equal', 'less'):
                return space.wrap(False)
-            return space.wrap(True) 
+            return space.wrap(True)
         elif (w_rdtype.is_str()) and \
                 self.comparison_func and w_out is None:
             if self.name in ('not_equal','less', 'less_equal'):
@@ -955,170 +954,6 @@
 )
 
 
-def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False,
-        promote_bools=False):
-    if dt2 is None:
-        return dt1
-    if dt1 is None:
-        return dt2
-
-    if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT:
-        return get_dtype_cache(space).w_objectdtype
-
-    # dt1.num should be <= dt2.num
-    if dt1.num > dt2.num:
-        dt1, dt2 = dt2, dt1
-    # Some operations promote op(bool, bool) to return int8, rather than bool
-    if promote_bools and (dt1.kind == dt2.kind == NPY.GENBOOLLTR):
-        return get_dtype_cache(space).w_int8dtype
-
-    # Everything numeric promotes to complex
-    if dt2.is_complex() or dt1.is_complex():
-        if dt2.num == NPY.HALF:
-            dt1, dt2 = dt2, dt1
-        if dt2.num == NPY.CFLOAT:
-            if dt1.num == NPY.DOUBLE:
-                return get_dtype_cache(space).w_complex128dtype
-            elif dt1.num == NPY.LONGDOUBLE:
-                return get_dtype_cache(space).w_complexlongdtype
-            return get_dtype_cache(space).w_complex64dtype
-        elif dt2.num == NPY.CDOUBLE:
-            if dt1.num == NPY.LONGDOUBLE:
-                return get_dtype_cache(space).w_complexlongdtype
-            return get_dtype_cache(space).w_complex128dtype
-        elif dt2.num == NPY.CLONGDOUBLE:
-            return get_dtype_cache(space).w_complexlongdtype
-        else:
-            raise OperationError(space.w_TypeError, space.wrap("Unsupported types"))
-
-    if promote_to_float:
-        return find_unaryop_result_dtype(space, dt2, promote_to_float=True)
-    # If they're the same kind, choose the greater one.
-    if dt1.kind == dt2.kind and not dt2.is_flexible():
-        if dt2.num == NPY.HALF:
-            return dt1
-        return dt2
-
-    # Everything promotes to float, and bool promotes to everything.
-    if dt2.kind == NPY.FLOATINGLTR or dt1.kind == NPY.GENBOOLLTR:
-        if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() == 2:
-            return get_dtype_cache(space).w_float32dtype
-        if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() >= 4:
-            return get_dtype_cache(space).w_float64dtype
-        if dt2.num == NPY.FLOAT and dt1.itemtype.get_element_size() >= 4:
-            return get_dtype_cache(space).w_float64dtype
-        return dt2
-
-    # for now this means mixing signed and unsigned
-    if dt2.kind == NPY.SIGNEDLTR:
-        # if dt2 has a greater number of bytes, then just go with it
-        if dt1.itemtype.get_element_size() < dt2.itemtype.get_element_size():
-            return dt2
-        # we need to promote both dtypes
-        dtypenum = dt2.num + 2
-    elif dt2.num == NPY.ULONGLONG or (LONG_BIT == 64 and dt2.num == NPY.ULONG):
-        # UInt64 + signed = Float64
-        dtypenum = NPY.DOUBLE
-    elif dt2.is_flexible():
-        # For those operations that get here (concatenate, stack),
-        # flexible types take precedence over numeric type
-        if dt2.is_record():
-            return dt2
-        if dt1.is_str_or_unicode():
-            if dt2.elsize >= dt1.elsize:
-                return dt2
-            return dt1
-        return dt2
-    else:
-        # increase to the next signed type
-        dtypenum = dt2.num + 1
-    newdtype = get_dtype_cache(space).dtypes_by_num[dtypenum]
-
-    if (newdtype.itemtype.get_element_size() > dt2.itemtype.get_element_size() or
-            newdtype.kind == NPY.FLOATINGLTR):
-        return newdtype
-    else:
-        # we only promoted to long on 32-bit or to longlong on 64-bit
-        # this is really for dealing with the Long and Ulong dtypes
-        dtypenum += 2
-        return get_dtype_cache(space).dtypes_by_num[dtypenum]
-
-
-@jit.unroll_safe
-def find_unaryop_result_dtype(space, dt, promote_to_float=False,
-        promote_bools=False, promote_to_largest=False):
-    if dt.is_object():
-        return dt
-    if promote_to_largest:
-        if dt.kind == NPY.GENBOOLLTR or dt.kind == NPY.SIGNEDLTR:
-            if dt.elsize * 8 < LONG_BIT:
-                return get_dtype_cache(space).w_longdtype
-        elif dt.kind == NPY.UNSIGNEDLTR:
-            if dt.elsize * 8 < LONG_BIT:
-                return get_dtype_cache(space).w_ulongdtype
-        else:
-            assert dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR
-        return dt
-    if promote_bools and (dt.kind == NPY.GENBOOLLTR):
-        return get_dtype_cache(space).w_int8dtype
-    if promote_to_float:
-        if dt.kind == NPY.FLOATINGLTR or dt.kind == NPY.COMPLEXLTR:
-            return dt
-        if dt.num >= NPY.INT:
-            return get_dtype_cache(space).w_float64dtype
-        for bytes, dtype in get_dtype_cache(space).float_dtypes_by_num_bytes:
-            if (dtype.kind == NPY.FLOATINGLTR and
-                    dtype.itemtype.get_element_size() >
-                    dt.itemtype.get_element_size()):
-                return dtype
-    return dt
-
-
-def find_dtype_for_scalar(space, w_obj, current_guess=None):
-    bool_dtype = get_dtype_cache(space).w_booldtype
-    long_dtype = get_dtype_cache(space).w_longdtype
-    int64_dtype = get_dtype_cache(space).w_int64dtype
-    uint64_dtype = get_dtype_cache(space).w_uint64dtype
-    complex_dtype = get_dtype_cache(space).w_complex128dtype
-    float_dtype = get_dtype_cache(space).w_float64dtype
-    object_dtype = get_dtype_cache(space).w_objectdtype
-    if isinstance(w_obj, boxes.W_GenericBox):
-        dtype = w_obj.get_dtype(space)
-        return find_binop_result_dtype(space, dtype, current_guess)
-
-    if space.isinstance_w(w_obj, space.w_bool):
-        return find_binop_result_dtype(space, bool_dtype, current_guess)
-    elif space.isinstance_w(w_obj, space.w_int):
-        return find_binop_result_dtype(space, long_dtype, current_guess)
-    elif space.isinstance_w(w_obj, space.w_long):
-        try:
-            space.int_w(w_obj)
-        except OperationError, e:
-            if e.match(space, space.w_OverflowError):
-                if space.is_true(space.le(w_obj, space.wrap(0))):
-                    return find_binop_result_dtype(space, int64_dtype,
-                                               current_guess)
-                return find_binop_result_dtype(space, uint64_dtype,
-                                               current_guess)
-            raise
-        return find_binop_result_dtype(space, int64_dtype, current_guess)
-    elif space.isinstance_w(w_obj, space.w_float):
-        return find_binop_result_dtype(space, float_dtype, current_guess)
-    elif space.isinstance_w(w_obj, space.w_complex):
-        return complex_dtype
-    elif space.isinstance_w(w_obj, space.w_str):
-        if current_guess is None:
-            return variable_dtype(space,
-                                               'S%d' % space.len_w(w_obj))
-        elif current_guess.num == NPY.STRING:
-            if current_guess.elsize < space.len_w(w_obj):
-                return variable_dtype(space,
-                                                   'S%d' % space.len_w(w_obj))
-        return current_guess
-    return object_dtype
-    #raise oefmt(space.w_NotImplementedError,
-    #            'unable to create dtype from objects, "%T" instance not '
-    #            'supported', w_obj)
 
 
 def ufunc_dtype_caller(space, ufunc_name, op_name, nin, comparison_func,
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to