Author: Justin Peel <[email protected]>
Branch: numpy-dtype
Changeset: r46226:959ca3a44df9
Date: 2011-08-02 23:30 -0600
http://bitbucket.org/pypy/pypy/changeset/959ca3a44df9/
Log: added find_result_dtype. binops should work correctly now.
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -46,6 +46,8 @@
UNSIGNEDLTR = 'u'
COMPLEXLTR = 'c'
+kind_dict = {'b': 0, 'u': 1, 'i': 1, 'f': 2, 'c': 2}
+
class Dtype(Wrappable):
# attributes: type, kind, typeobj?(I think it should point to np.float64 or
# the like), byteorder, flags, type_num, elsize, alignment, subarray,
@@ -174,14 +176,41 @@
raise OperationError(space.w_TypeError,
space.wrap("data type not understood"))
-def find_base_dtype(dtype1, dtype2):
+def find_result_dtype(d1, d2):
+ # this function is for determining the result dtype of bin ops, etc.
+ # it is kind of a mess so feel free to improve it
+
+ # first make sure larger num is in d2
+ if d1.num > d2.num:
+ dtype1 = d2
+ dtype2 = d1
+ else:
+ dtype1 = d1
+ dtype2 = d2
num1 = dtype1.num
num2 = dtype2.num
- # this is much more complex
- if num1 < num2:
+ kind1 = dtype1.kind
+ kind2 = dtype2.kind
+ if kind1 == kind2:
+ # dtype2 has the greater number
return dtype2
- return dtype
-
+ kind_num1 = kind_dict[kind1]
+ kind_num2 = kind_dict[kind2]
+ if kind_num1 == kind_num2: # two kinds of integers or float and complex
+ # XXX: Need to deal with float and complex combo here also
+ if kind2 == SIGNEDLTR:
+ return dtype2
+ if num2 < UInt32_num:
+ return _dtype_list[num2+1]
+ if num2 == UInt64_num or (LONG_BIT == 64 and num2 == Long_num): # UInt64
+ return Float64_dtype
+ # dtype2 is uint32
+ return Int64_dtype
+ if kind_num1 == 1: # is an integer
+ if num2 == Float32_num and num2 == UInt64_num or \
+ (LONG_BIT == 64 and num2 == Long_num):
+ return Float64_dtype
+ return dtype2
def descr_new_dtype(space, w_type, w_string_or_type):
return space.wrap(get_dtype(space, w_type, w_string_or_type))
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -2,7 +2,7 @@
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.interpreter.gateway import interp2app, unwrap_spec
from pypy.interpreter.typedef import TypeDef, GetSetProperty
-from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, Float64_dtype, get_dtype, find_scalar_dtype, find_base_dtype
+from pypy.module.micronumpy.interp_dtype import Dtype, Float64_num, Int32_num, Float64_dtype, get_dtype, find_scalar_dtype, find_result_dtype
from pypy.module.micronumpy.interp_support import Signature
from pypy.module.micronumpy import interp_ufuncs
from pypy.objspace.std.floatobject import float2string as float2string_orig
@@ -417,16 +417,24 @@
def __init__(self, function, left, right, signature):
VirtualArray.__init__(self, signature)
- self.function = function
self.left = left
self.right = right
dtype = self.left.find_dtype()
dtype2 = self.right.find_dtype()
- # this is more complicated than this.
- # for instance int32 + uint32 = int64
- if dtype.num != dtype.num:
- dtype = find_base_dtype(dtype, dtype2)
- self.dtype = dtype
+ if dtype.num != dtype2.num:
+ newdtype = find_result_dtype(dtype, dtype2)
+ cast = newdtype.cast
+ if dtype.num != newdtype.num:
+ if dtype2.num != newdtype.num:
+ self.function = lambda x, y: function(cast(x), cast(y))
+ else:
+ self.function = lambda x, y: function(cast(x), y)
+ else:
+ self.function = lambda x, y: function(x, cast(y))
+ self.dtype = newdtype
+ else:
+ self.dtype = dtype
+ self.function = function
def _del_sources(self):
self.left = None
diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -35,3 +35,13 @@
assert a[0] == 1
assert a[1] == 2
assert a[2] == 3
+
+ def test_bool_binop_types(self):
+ from numpy import array, dtype
+ types = ('?','b','B','h','H','i','I','l','L','q','Q','f','d','g')
+ dtypes = [dtype(t) for t in types]
+ N = len(types)
+ a = array([True],'?')
+ for i in xrange(N):
+ assert (a + array([0], types[i])).dtype is dtypes[i]
+# need more tests for binop result types
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit