Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: fix-result-types
Changeset: r77671:75c1167b4588
Date: 2015-05-29 03:37 +0100
http://bitbucket.org/pypy/pypy/changeset/75c1167b4588/

Log:    Move can_cast_to() and can_cast_itemtype() to casting.py

diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -1,17 +1,19 @@
 """Functions and helpers for converting between dtypes"""
 
 from rpython.rlib import jit
+from rpython.rlib.signature import signature, types as ann
 from pypy.interpreter.gateway import unwrap_spec
 from pypy.interpreter.error import oefmt, OperationError
 
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy import constants as NPY
 from .types import (
-    Bool, ULong, Long, Float64, Complex64, StringType, UnicodeType, VoidType, ObjectType,
+    BaseType, Bool, ULong, Long, Float64, Complex64,
+    StringType, UnicodeType, VoidType, ObjectType,
     int_types, float_types, complex_types, number_types, all_types)
 from .descriptor import (
-    get_dtype_cache, as_dtype, is_scalar_w, variable_dtype, new_string_dtype,
-    new_unicode_dtype, num2dtype)
+    W_Dtype, get_dtype_cache, as_dtype, is_scalar_w, variable_dtype,
+    new_string_dtype, new_unicode_dtype, num2dtype)
 
 @jit.unroll_safe
 def result_type(space, __args__):
@@ -153,13 +155,13 @@
     elif casting == 'unsafe':
         return True
     elif casting == 'same_kind':
-        if origin.can_cast_to(target):
+        if can_cast_to(origin, target):
             return True
         if origin.kind in kind_ordering and target.kind in kind_ordering:
             return kind_ordering[origin.kind] <= kind_ordering[target.kind]
         return False
     else:  # 'safe'
-        return origin.can_cast_to(target)
+        return can_cast_to(origin, target)
 
 def can_cast_record(space, origin, target, casting):
     if origin is target:
@@ -325,6 +327,37 @@
         return variable_dtype(space, 'S%d' % space.len_w(w_obj))
     return object_dtype
 
+@signature(ann.instance(W_Dtype), ann.instance(W_Dtype), returns=ann.bool())
+def can_cast_to(dt1, dt2):
+    """Return whether dtype `dt1` can be cast safely to `dt2`"""
+    # equivalent to PyArray_CanCastTo
+    # can_cast_itemtype (below) is a module global: no self-import needed
+    result = can_cast_itemtype(dt1.itemtype, dt2.itemtype)
+    if result:
+        if dt1.num == NPY.STRING:
+            if dt2.num == NPY.STRING:
+                return dt1.elsize <= dt2.elsize
+            elif dt2.num == NPY.UNICODE:
+                return dt1.elsize * 4 <= dt2.elsize  # 4 bytes per char
+        elif dt1.num == NPY.UNICODE and dt2.num == NPY.UNICODE:
+            return dt1.elsize <= dt2.elsize
+        elif dt2.num in (NPY.STRING, NPY.UNICODE):
+            if dt2.num == NPY.STRING:
+                char_size = 1
+            else:  # NPY.UNICODE
+                char_size = 4
+            if dt2.elsize == 0:  # presumably an unsized target dtype
+                return True
+            if dt1.is_int():
+                return dt2.elsize >= dt1.itemtype.strlen * char_size
+    return result
+
+
+@signature(ann.instance(BaseType), ann.instance(BaseType), returns=ann.bool())
+def can_cast_itemtype(tp1, tp2):
+    # equivalent to PyArray_CanCastSafely: a lookup in casting_table
+    return casting_table[tp1.num][tp2.num]
+
 #_________________________
 
 
@@ -334,6 +367,7 @@
     casting_table[type1.num][type2.num] = True
 
 def _can_cast(type1, type2):
+    """NOT_RPYTHON: operates on BaseType subclasses"""
     return casting_table[type1.num][type2.num]
 
 for tp in all_types:
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -8,7 +8,6 @@
 from rpython.rlib import jit
 from rpython.rlib.objectmodel import specialize, compute_hash, we_are_translated
 from rpython.rlib.rarithmetic import r_longlong, r_ulonglong
-from rpython.rlib.signature import finishsigs, signature, types as ann
 from pypy.module.micronumpy import types, boxes, support, constants as NPY
 from .base import W_NDimArray
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
@@ -41,7 +40,6 @@
 
 
 
-@finishsigs
 class W_Dtype(W_Root):
     _immutable_fields_ = [
         "itemtype?", "w_box_type", "byteorder?", "names?", "fields?",
@@ -95,29 +93,6 @@
     def box_complex(self, real, imag):
         return self.itemtype.box_complex(real, imag)
 
-    @signature(ann.self(), ann.self(), returns=ann.bool())
-    def can_cast_to(self, other):
-        # equivalent to PyArray_CanCastTo
-        result = self.itemtype.can_cast_to(other.itemtype)
-        if result:
-            if self.num == NPY.STRING:
-                if other.num == NPY.STRING:
-                    return self.elsize <= other.elsize
-                elif other.num == NPY.UNICODE:
-                    return self.elsize * 4 <= other.elsize
-            elif self.num == NPY.UNICODE and other.num == NPY.UNICODE:
-                return self.elsize <= other.elsize
-            elif other.num in (NPY.STRING, NPY.UNICODE):
-                if other.num == NPY.STRING:
-                    char_size = 1
-                else:  # NPY.UNICODE
-                    char_size = 4
-                if other.elsize == 0:
-                    return True
-                if self.is_int():
-                    return other.elsize >= self.itemtype.strlen * char_size
-        return result
-
     def coerce(self, space, w_item):
         return self.itemtype.coerce(space, self, w_item)
 
@@ -311,20 +286,24 @@
         return space.wrap(not self.eq(space, w_other))
 
     def descr_le(self, space, w_other):
+        from .casting import can_cast_to
         w_other = as_dtype(space, w_other)
-        return space.wrap(self.can_cast_to(w_other))
+        return space.wrap(can_cast_to(self, w_other))
 
     def descr_ge(self, space, w_other):
+        from .casting import can_cast_to
         w_other = as_dtype(space, w_other)
-        return space.wrap(w_other.can_cast_to(self))
+        return space.wrap(can_cast_to(w_other, self))
 
     def descr_lt(self, space, w_other):
+        from .casting import can_cast_to
         w_other = as_dtype(space, w_other)
-        return space.wrap(self.can_cast_to(w_other) and not self.eq(space, w_other))
+        return space.wrap(can_cast_to(self, w_other) and not self.eq(space, w_other))
 
     def descr_gt(self, space, w_other):
+        from .casting import can_cast_to
         w_other = as_dtype(space, w_other)
-        return space.wrap(w_other.can_cast_to(self) and not self.eq(space, w_other))
+        return space.wrap(can_cast_to(w_other, self) and not self.eq(space, w_other))
 
     def _compute_hash(self, space, x):
         from rpython.rlib.rarithmetic import intmask
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -154,11 +154,6 @@
     def basesize(cls):
         return rffi.sizeof(cls.T)
 
-    def can_cast_to(self, other):
-        # equivalent to PyArray_CanCastSafely
-        from .casting import casting_table
-        return casting_table[self.num][other.num]
-
 class Primitive(object):
     _mixin_ = True
 
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -20,7 +20,8 @@
 from pypy.module.micronumpy.strides import shape_agreement
 from pypy.module.micronumpy.support import (_parse_signature, product,
         get_storage_as_int, is_rhs_priority_higher)
-from .casting import can_cast_type, find_result_type, promote_types
+from .casting import (
+    can_cast_type, can_cast_to, find_result_type, promote_types)
 from .boxes import W_GenericBox, W_ObjectBox
 
 def done_if_true(dtype, val):
@@ -668,14 +669,14 @@
         if dtype.is_object():
             return dtype
         for dt_in, dt_out in self.dtypes:
-            if dtype.can_cast_to(dt_in):
+            if can_cast_to(dtype, dt_in):
                 if dt_out == dt_in:
                     return dt_in
                 else:
                     dtype = dt_out
                     break
         for dt_in, dt_out in self.dtypes:
-            if dtype.can_cast_to(dt_in) and dt_out == dt_in:
+            if can_cast_to(dtype, dt_in) and dt_out == dt_in:
                 return dt_in
         raise ValueError(
             "could not find a matching type for %s.accumulate, "
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to