Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: 
Changeset: r77796:cad6015d5380
Date: 2015-06-03 00:04 +0100
http://bitbucket.org/pypy/pypy/changeset/cad6015d5380/

Log:    Merged use_min_scalar into default

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1349,3 +1349,6 @@
         assert np.add(np.float16(0), np.longdouble(0)).dtype == np.longdouble
         assert np.add(np.float16(0), np.complex64(0)).dtype == np.complex64
         assert np.add(np.float16(0), np.complex128(0)).dtype == np.complex128
+        assert np.add(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
+        assert np.subtract(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
+        assert np.divide(np.zeros(5, dtype=np.int8), 257).dtype == np.int16
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -21,7 +21,8 @@
 from pypy.module.micronumpy.support import (_parse_signature, product,
         get_storage_as_int, is_rhs_priority_higher)
 from .casting import (
-    can_cast_type, can_cast_to, find_result_type, promote_types)
+    can_cast_type, can_cast_array, can_cast_to,
+    find_result_type, promote_types)
 from .boxes import W_GenericBox, W_ObjectBox
 
 def done_if_true(dtype, val):
@@ -495,17 +496,12 @@
         return dt_in, dt_out, self.func
 
     def _calc_dtype(self, space, arg_dtype, out=None, casting='unsafe'):
-        use_min_scalar = False
         if arg_dtype.is_object():
             return arg_dtype, arg_dtype
         in_casting = safe_casting_mode(casting)
         for dt_in, dt_out in self.dtypes:
-            if use_min_scalar:
-                if not can_cast_array(space, w_arg, dt_in, in_casting):
-                    continue
-            else:
-                if not can_cast_type(space, arg_dtype, dt_in, in_casting):
-                    continue
+            if not can_cast_type(space, arg_dtype, dt_in, in_casting):
+                continue
             if out is not None:
                 res_dtype = out.get_dtype()
                 if not can_cast_type(space, dt_out, res_dtype, casting):
@@ -605,21 +601,18 @@
                             w_rdtype.get_name(), w_ldtype.get_name(),
                             self.name)
 
-        if self.are_common_types(w_ldtype, w_rdtype):
-            if not w_lhs.is_scalar() and w_rhs.is_scalar():
-                w_rdtype = w_ldtype
-            elif w_lhs.is_scalar() and not w_rhs.is_scalar():
-                w_ldtype = w_rdtype
-        calc_dtype, dt_out, func = self.find_specialization(space, w_ldtype, 
w_rdtype, out, casting)
         if (isinstance(w_lhs, W_GenericBox) and
                 isinstance(w_rhs, W_GenericBox) and out is None):
-            return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
+            return self.call_scalar(space, w_lhs, w_rhs, casting)
         if isinstance(w_lhs, W_GenericBox):
             w_lhs = W_NDimArray.from_scalar(space, w_lhs)
         assert isinstance(w_lhs, W_NDimArray)
         if isinstance(w_rhs, W_GenericBox):
             w_rhs = W_NDimArray.from_scalar(space, w_rhs)
         assert isinstance(w_rhs, W_NDimArray)
+        calc_dtype, dt_out, func = self.find_specialization(
+            space, w_ldtype, w_rdtype, out, casting, w_lhs, w_rhs)
+
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, 
broadcast_down=False)
         w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
@@ -637,7 +630,10 @@
             w_res = space.call_method(w_highpriority, '__array_wrap__', w_res, 
ctxt)
         return w_res
 
-    def call_scalar(self, space, w_lhs, w_rhs, in_dtype):
+    def call_scalar(self, space, w_lhs, w_rhs, casting):
+        in_dtype, out_dtype, func = self.find_specialization(
+            space, w_lhs.get_dtype(space), w_rhs.get_dtype(space),
+            out=None, casting=casting)
         w_val = self.func(in_dtype,
                           w_lhs.convert_to(space, in_dtype),
                           w_rhs.convert_to(space, in_dtype))
@@ -645,7 +641,8 @@
             return w_val.w_obj
         return w_val
 
-    def _find_specialization(self, space, l_dtype, r_dtype, out, casting):
+    def _find_specialization(self, space, l_dtype, r_dtype, out, casting,
+                             w_arg1, w_arg2):
         if (not self.allow_bool and (l_dtype.is_bool() or
                                          r_dtype.is_bool()) or
                 not self.allow_complex and (l_dtype.is_complex() or
@@ -657,15 +654,23 @@
             dtype = find_result_type(space, [], [l_dtype, r_dtype])
             bool_dtype = get_dtype_cache(space).w_booldtype
             return dtype, bool_dtype, self.func
-        dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
+        dt_in, dt_out = self._calc_dtype(
+            space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
         return dt_in, dt_out, self.func
 
-    def find_specialization(self, space, l_dtype, r_dtype, out, casting):
+    def find_specialization(self, space, l_dtype, r_dtype, out, casting,
+                            w_arg1=None, w_arg2=None):
         if self.simple_binary:
             if out is None and not (l_dtype.is_object() or 
r_dtype.is_object()):
-                dtype = promote_types(space, l_dtype, r_dtype)
+                if w_arg1 is not None and w_arg2 is not None:
+                    w_arg1 = convert_to_array(space, w_arg1)
+                    w_arg2 = convert_to_array(space, w_arg2)
+                    dtype = find_result_type(space, [w_arg1, w_arg2], [])
+                else:
+                    dtype = promote_types(space, l_dtype, r_dtype)
                 return dtype, dtype, self.func
-        return self._find_specialization(space, l_dtype, r_dtype, out, casting)
+        return self._find_specialization(
+            space, l_dtype, r_dtype, out, casting, w_arg1, w_arg2)
 
     def find_binop_type(self, space, dtype):
         """Find a valid dtype signature of the form xx->x"""
@@ -686,15 +691,21 @@
             "requested type has type code '%s'" % (self.name, dtype.char))
 
 
-    def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'):
-        use_min_scalar = False
+    def _calc_dtype(self, space, l_dtype, r_dtype, out, casting,
+                    w_arg1, w_arg2):
         if l_dtype.is_object() or r_dtype.is_object():
             dtype = get_dtype_cache(space).w_objectdtype
             return dtype, dtype
+        use_min_scalar = (w_arg1 is not None and w_arg2 is not None and
+                          ((w_arg1.is_scalar() and not w_arg2.is_scalar()) or
+                           (not w_arg1.is_scalar() and w_arg2.is_scalar())))
         in_casting = safe_casting_mode(casting)
         for dt_in, dt_out in self.dtypes:
             if use_min_scalar:
-                if not can_cast_array(space, w_arg, dt_in, in_casting):
+                w_arg1 = convert_to_array(space, w_arg1)
+                w_arg2 = convert_to_array(space, w_arg2)
+                if not (can_cast_array(space, w_arg1, dt_in, in_casting) and
+                        can_cast_array(space, w_arg2, dt_in, in_casting)):
                     continue
             else:
                 if not (can_cast_type(space, l_dtype, dt_in, in_casting) and
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to