Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: fix-result-types
Changeset: r77296:2fc8c1b68f07
Date: 2015-05-12 06:16 +0100
http://bitbucket.org/pypy/pypy/changeset/2fc8c1b68f07/

Log:    Use the same dtype-resolution logic as CPython's NumPy (cnumpy) in W_Ufunc1.find_specialization()

diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -900,17 +900,22 @@
             NPY.CDOUBLE:     self.w_float64dtype,
             NPY.CLONGDOUBLE: self.w_floatlongdtype,
         }
-        self.builtin_dtypes = [
-            self.w_booldtype,
+        # Order matters: W_Ufunc1.allowed_types() lists these in this order
+        # and the first one the input can be safely cast to is used.
+        integer_dtypes = [
             self.w_int8dtype, self.w_uint8dtype,
             self.w_int16dtype, self.w_uint16dtype,
+            self.w_int32dtype, self.w_uint32dtype,
             self.w_longdtype, self.w_ulongdtype,
-            self.w_int32dtype, self.w_uint32dtype,
-            self.w_int64dtype, self.w_uint64dtype,
-            ] + float_dtypes + complex_dtypes + [
-            self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype,
-            self.w_objectdtype,
-        ]
+            self.w_int64dtype, self.w_uint64dtype]
+        self.builtin_dtypes = ([self.w_booldtype] + integer_dtypes +
+            float_dtypes + complex_dtypes + [
+                self.w_stringdtype, self.w_unicodedtype, self.w_voiddtype,
+                self.w_objectdtype,
+            ])
+        self.integer_dtypes = integer_dtypes
+        self.float_dtypes = float_dtypes
+        self.complex_dtypes = complex_dtypes
         self.float_dtypes_by_num_bytes = sorted(
             (dtype.elsize, dtype)
             for dtype in float_dtypes
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1,5 +1,5 @@
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.ufuncs import W_UfuncGeneric
+from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, W_Ufunc1
 from pypy.module.micronumpy.support import _parse_signature
 from pypy.module.micronumpy.descriptor import get_dtype_cache
 from pypy.module.micronumpy.base import W_NDimArray
@@ -54,6 +54,24 @@
         exc = raises(OperationError, ufunc.type_resolver, space, [f32_array], [None],
                                 'i->i', ufunc.dtypes)
 
+    def test_allowed_types(self, space):
+        """W_Ufunc1._calc_dtype() picks the first allowed dtype that the
+        argument dtype can be safely cast to."""
+        dt_bool = get_dtype_cache(space).w_booldtype
+        dt_float16 = get_dtype_cache(space).w_float16dtype
+        dt_int32 = get_dtype_cache(space).w_int32dtype
+        ufunc = W_Ufunc1(None, 'x', int_only=True)
+        assert ufunc._calc_dtype(space, dt_bool) == dt_bool
+        # XXX: allowed_types() ignores int_only, so float and complex
+        # dtypes are listed too; the list should be narrowed.
+        assert ufunc.allowed_types(space)
+
+        ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
+        assert ufunc._calc_dtype(space, dt_bool) == dt_float16
+
+        ufunc = W_Ufunc1(None, 'x')
+        assert ufunc._calc_dtype(space, dt_int32) == dt_int32
+
 class AppTestUfuncs(BaseNumpyAppTest):
     def test_constants(self):
         import numpy as np
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -18,7 +18,8 @@
 from pypy.module.micronumpy.nditer import W_NDIter, coalesce_iter
 from pypy.module.micronumpy.strides import shape_agreement
 from pypy.module.micronumpy.support import _parse_signature, product, get_storage_as_int
-from .casting import find_unaryop_result_dtype, find_binop_result_dtype
+from .casting import (
+    find_unaryop_result_dtype, find_binop_result_dtype, can_cast_type)
 
 def done_if_true(dtype, val):
     return dtype.itemtype.bool(val)
@@ -384,12 +385,44 @@
                 not self.allow_complex and dtype.is_complex()):
             raise oefmt(space.w_TypeError,
                 "ufunc %s not supported for the input type", self.name)
-        calc_dtype = find_unaryop_result_dtype(space,
-                                  dtype,
-                                  promote_to_float=self.promote_to_float,
-                                  promote_bools=self.promote_bools)
+        calc_dtype = self._calc_dtype(space, dtype)
         return calc_dtype, self.func
 
+    def _calc_dtype(self, space, arg_dtype):
+        """Return the dtype in which to compute this ufunc for an argument
+        of dtype `arg_dtype`.
+
+        Object dtypes are returned unchanged; otherwise this is the first
+        entry of self.allowed_types(space) that `arg_dtype` can be safely
+        cast to.  Raises TypeError if there is none.
+        """
+        if arg_dtype.is_object():
+            return arg_dtype
+        # NOTE(review): only the argument's dtype is considered here;
+        # value-based ("min scalar") casting is not implemented.
+        for dtype in self.allowed_types(space):
+            if can_cast_type(space, arg_dtype, dtype, casting='safe'):
+                return dtype
+        raise oefmt(space.w_TypeError,
+            "No loop matching the specified signature was found "
+            "for ufunc %s", self.name)
+
+    def allowed_types(self, space):
+        """Return the candidate dtypes, in the order _calc_dtype tries
+        them: bool (omitted if promote_bools or promote_to_float is set),
+        the integer dtypes (omitted if promote_to_float is set), then all
+        float dtypes and all complex dtypes.
+        """
+        dtypes = []
+        cache = get_dtype_cache(space)
+        if not self.promote_bools and not self.promote_to_float:
+            dtypes.append(cache.w_booldtype)
+        if not self.promote_to_float:
+            dtypes.extend(cache.integer_dtypes)
+        dtypes.extend(cache.float_dtypes)
+        dtypes.extend(cache.complex_dtypes)
+        return dtypes
+

 class W_Ufunc2(W_Ufunc):
     _immutable_fields_ = ["func", "comparison_func", "done_func"]
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to