Author: Ronan Lamy <ronan.l...@gmail.com> Branch: fix-result-types Changeset: r77671:75c1167b4588 Date: 2015-05-29 03:37 +0100 http://bitbucket.org/pypy/pypy/changeset/75c1167b4588/
Log: Move can_cast_to() and can_cast_itemtype() to casting.py diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py --- a/pypy/module/micronumpy/casting.py +++ b/pypy/module/micronumpy/casting.py @@ -1,17 +1,19 @@ """Functions and helpers for converting between dtypes""" from rpython.rlib import jit +from rpython.rlib.signature import signature, types as ann from pypy.interpreter.gateway import unwrap_spec from pypy.interpreter.error import oefmt, OperationError from pypy.module.micronumpy.base import W_NDimArray, convert_to_array from pypy.module.micronumpy import constants as NPY from .types import ( - Bool, ULong, Long, Float64, Complex64, StringType, UnicodeType, VoidType, ObjectType, + BaseType, Bool, ULong, Long, Float64, Complex64, + StringType, UnicodeType, VoidType, ObjectType, int_types, float_types, complex_types, number_types, all_types) from .descriptor import ( - get_dtype_cache, as_dtype, is_scalar_w, variable_dtype, new_string_dtype, - new_unicode_dtype, num2dtype) + W_Dtype, get_dtype_cache, as_dtype, is_scalar_w, variable_dtype, + new_string_dtype, new_unicode_dtype, num2dtype) @jit.unroll_safe def result_type(space, __args__): @@ -153,13 +155,13 @@ elif casting == 'unsafe': return True elif casting == 'same_kind': - if origin.can_cast_to(target): + if can_cast_to(origin, target): return True if origin.kind in kind_ordering and target.kind in kind_ordering: return kind_ordering[origin.kind] <= kind_ordering[target.kind] return False else: # 'safe' - return origin.can_cast_to(target) + return can_cast_to(origin, target) def can_cast_record(space, origin, target, casting): if origin is target: @@ -325,6 +327,37 @@ return variable_dtype(space, 'S%d' % space.len_w(w_obj)) return object_dtype +@signature(ann.instance(W_Dtype), ann.instance(W_Dtype), returns=ann.bool()) +def can_cast_to(dt1, dt2): + """Return whether dtype `dt1` can be cast safely to `dt2`""" + # equivalent to PyArray_CanCastTo + from .casting import can_cast_itemtype + result = can_cast_itemtype(dt1.itemtype, dt2.itemtype) + if result: + if dt1.num == NPY.STRING: + if dt2.num == NPY.STRING: + return dt1.elsize <= dt2.elsize + elif dt2.num == NPY.UNICODE: + return dt1.elsize * 4 <= dt2.elsize + elif dt1.num == NPY.UNICODE and dt2.num == NPY.UNICODE: + return dt1.elsize <= dt2.elsize + elif dt2.num in (NPY.STRING, NPY.UNICODE): + if dt2.num == NPY.STRING: + char_size = 1 + else: # NPY.UNICODE + char_size = 4 + if dt2.elsize == 0: + return True + if dt1.is_int(): + return dt2.elsize >= dt1.itemtype.strlen * char_size + return result + + +@signature(ann.instance(BaseType), ann.instance(BaseType), returns=ann.bool()) +def can_cast_itemtype(tp1, tp2): + # equivalent to PyArray_CanCastSafely + return casting_table[tp1.num][tp2.num] + #_________________________ @@ -334,6 +367,7 @@ casting_table[type1.num][type2.num] = True def _can_cast(type1, type2): + """NOT_RPYTHON: operates on BaseType subclasses""" return casting_table[type1.num][type2.num] for tp in all_types: diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py --- a/pypy/module/micronumpy/descriptor.py +++ b/pypy/module/micronumpy/descriptor.py @@ -8,7 +8,6 @@ from rpython.rlib import jit from rpython.rlib.objectmodel import specialize, compute_hash, we_are_translated from rpython.rlib.rarithmetic import r_longlong, r_ulonglong -from rpython.rlib.signature import finishsigs, signature, types as ann from pypy.module.micronumpy import types, boxes, support, constants as NPY from .base import W_NDimArray from pypy.module.micronumpy.appbridge import get_appbridge_cache @@ -41,7 +40,6 @@ -@finishsigs class W_Dtype(W_Root): _immutable_fields_ = [ "itemtype?", "w_box_type", "byteorder?", "names?", "fields?", @@ -95,29 +93,6 @@ def box_complex(self, real, imag): return self.itemtype.box_complex(real, imag) - @signature(ann.self(), ann.self(), returns=ann.bool()) - def can_cast_to(self, other): - # equivalent to PyArray_CanCastTo - result = self.itemtype.can_cast_to(other.itemtype) - if result: - if self.num == NPY.STRING: - if other.num == NPY.STRING: - return self.elsize <= other.elsize - elif other.num == NPY.UNICODE: - return self.elsize * 4 <= other.elsize - elif self.num == NPY.UNICODE and other.num == NPY.UNICODE: - return self.elsize <= other.elsize - elif other.num in (NPY.STRING, NPY.UNICODE): - if other.num == NPY.STRING: - char_size = 1 - else: # NPY.UNICODE - char_size = 4 - if other.elsize == 0: - return True - if self.is_int(): - return other.elsize >= self.itemtype.strlen * char_size - return result - def coerce(self, space, w_item): return self.itemtype.coerce(space, self, w_item) @@ -311,20 +286,24 @@ return space.wrap(not self.eq(space, w_other)) def descr_le(self, space, w_other): + from .casting import can_cast_to w_other = as_dtype(space, w_other) - return space.wrap(self.can_cast_to(w_other)) + return space.wrap(can_cast_to(self, w_other)) def descr_ge(self, space, w_other): + from .casting import can_cast_to w_other = as_dtype(space, w_other) - return space.wrap(w_other.can_cast_to(self)) + return space.wrap(can_cast_to(w_other, self)) def descr_lt(self, space, w_other): + from .casting import can_cast_to w_other = as_dtype(space, w_other) - return space.wrap(self.can_cast_to(w_other) and not self.eq(space, w_other)) + return space.wrap(can_cast_to(self, w_other) and not self.eq(space, w_other)) def descr_gt(self, space, w_other): + from .casting import can_cast_to w_other = as_dtype(space, w_other) - return space.wrap(w_other.can_cast_to(self) and not self.eq(space, w_other)) + return space.wrap(can_cast_to(w_other, self) and not self.eq(space, w_other)) def _compute_hash(self, space, x): from rpython.rlib.rarithmetic import intmask diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py --- a/pypy/module/micronumpy/types.py +++ b/pypy/module/micronumpy/types.py @@ -154,11 +154,6 @@ def basesize(cls): return rffi.sizeof(cls.T) - def can_cast_to(self, other): - # equivalent to PyArray_CanCastSafely - from .casting import casting_table - return casting_table[self.num][other.num] - class Primitive(object): _mixin_ = True diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -20,7 +20,8 @@ from pypy.module.micronumpy.strides import shape_agreement from pypy.module.micronumpy.support import (_parse_signature, product, get_storage_as_int, is_rhs_priority_higher) -from .casting import can_cast_type, find_result_type, promote_types +from .casting import ( + can_cast_type, can_cast_to, find_result_type, promote_types) from .boxes import W_GenericBox, W_ObjectBox def done_if_true(dtype, val): @@ -668,14 +669,14 @@ if dtype.is_object(): return dtype for dt_in, dt_out in self.dtypes: - if dtype.can_cast_to(dt_in): + if can_cast_to(dtype, dt_in): if dt_out == dt_in: return dt_in else: dtype = dt_out break for dt_in, dt_out in self.dtypes: - if dtype.can_cast_to(dt_in) and dt_out == dt_in: + if can_cast_to(dtype, dt_in) and dt_out == dt_in: return dt_in raise ValueError( "could not find a matching type for %s.accumulate, " _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit