Author: Ronan Lamy <[email protected]>
Branch: fix-result-types
Changeset: r77468:9e71be121b16
Date: 2015-05-22 01:10 +0100
http://bitbucket.org/pypy/pypy/changeset/9e71be121b16/

Log:    Handle record dtypes in can_cast()

diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -141,6 +141,11 @@
 
 def can_cast_type(space, origin, target, casting):
     # equivalent to PyArray_CanCastTypeTo
+    if origin == target:
+        return True
+    if origin.is_record() or target.is_record():
+        return can_cast_record(space, origin, target, casting)
+
     if casting == 'no':
         return origin.eq(space, target)
     elif casting == 'equiv':
@@ -156,6 +161,22 @@
     else:  # 'safe'
         return origin.can_cast_to(target)
 
+def can_cast_record(space, origin, target, casting):
+    if origin is target:
+        return True
+    if origin.fields is None or target.fields is None:
+        return False
+    if len(origin.fields) != len(target.fields):
+        return False
+    for name, (offset, orig_field) in origin.fields.iteritems():
+        if name not in target.fields:
+            return False
+        target_field = target.fields[name][1]
+        if not can_cast_type(space, orig_field, target_field, casting):
+            return False
+    return True
+
+
 def can_cast_array(space, w_from, target, casting):
     # equivalent to PyArray_CanCastArrayTo
     origin = w_from.get_dtype()
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -118,6 +118,15 @@
         assert np.can_cast(1., np.complex64)
         assert not np.can_cast(1e50, np.complex64)
 
+    def test_can_cast_record(self):
+        import numpy as np
+        rec1 = np.dtype([('x', int), ('y', float)])
+        rec2 = np.dtype([('x', float), ('y', float)])
+        rec3 = np.dtype([('y', np.float64), ('x', float)])
+        assert not np.can_cast(rec1, rec2, 'equiv')
+        assert np.can_cast(rec2, rec3, 'equiv')
+        assert np.can_cast(rec1, rec2)
+
     def test_min_scalar_type(self):
         import numpy as np
         assert np.min_scalar_type(2**8 - 1) == np.dtype('uint8')
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to