Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77468:9e71be121b16
Date: 2015-05-22 01:10 +0100
http://bitbucket.org/pypy/pypy/changeset/9e71be121b16/
Log: Handle record dtypes in can_cast()
diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -141,6 +141,11 @@
def can_cast_type(space, origin, target, casting):
# equivalent to PyArray_CanCastTypeTo
+ if origin == target:
+ return True
+ if origin.is_record() or target.is_record():
+ return can_cast_record(space, origin, target, casting)
+
if casting == 'no':
return origin.eq(space, target)
elif casting == 'equiv':
@@ -156,6 +161,22 @@
else: # 'safe'
return origin.can_cast_to(target)
def can_cast_record(space, origin, target, casting):
    """Return True if dtype `origin` can be cast to `target` under `casting`,
    where at least one of the two is a record (structured) dtype.

    Follows the structured-dtype handling of numpy's PyArray_CanCastTypeTo:

    * 'unsafe' casting permits any conversion, records included.
    * 'no' casting requires the dtypes to be equivalent
      (``origin.eq(space, target)``).  Castability field by field is not
      enough here, because offsets and field order must also agree.
    * For the remaining modes, both dtypes must have fields and the same
      set of field names.  Each field of `origin` must also be castable,
      via can_cast_type(), to the field of the same name in `target`.
    """
    if origin is target:
        return True
    if casting == 'unsafe':
        # Unsafe casting never rejects a conversion, whatever the layout.
        return True
    if casting == 'no':
        # Same test can_cast_type applies for 'no' on non-record dtypes.
        return origin.eq(space, target)
    if origin.fields is None or target.fields is None:
        return False
    if len(origin.fields) != len(target.fields):
        return False
    # `fields` maps name -> (offset, dtype).  Fields are matched by name,
    # not by position, and only their dtypes take part in the check.
    for name, (_offset, orig_field) in origin.fields.iteritems():
        if name not in target.fields:
            return False
        target_field = target.fields[name][1]
        if not can_cast_type(space, orig_field, target_field, casting):
            return False
    return True
+
+
def can_cast_array(space, w_from, target, casting):
# equivalent to PyArray_CanCastArrayTo
origin = w_from.get_dtype()
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -118,6 +118,15 @@
assert np.can_cast(1., np.complex64)
assert not np.can_cast(1e50, np.complex64)
def test_can_cast_record(self):
    """can_cast() on record dtypes compares the fields by name."""
    import numpy as np
    # Same field names; only the dtype of 'x' differs.
    int_float = np.dtype([('x', int), ('y', float)])
    float_float = np.dtype([('x', float), ('y', float)])
    # Same fields as float_float, but declared in the opposite order.
    reordered = np.dtype([('y', np.float64), ('x', float)])

    # int -> float is allowed by the default casting mode, but not by 'equiv'.
    assert not np.can_cast(int_float, float_float, 'equiv')
    assert np.can_cast(int_float, float_float)
    # Fields are matched by name, so declaration order does not matter.
    assert np.can_cast(float_float, reordered, 'equiv')
+
def test_min_scalar_type(self):
import numpy as np
assert np.min_scalar_type(2**8 - 1) == np.dtype('uint8')
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit