Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r69346:12fa4e02e3cf
Date: 2014-02-24 02:15 -0500
http://bitbucket.org/pypy/pypy/changeset/12fa4e02e3cf/

Log:    improve dtype setstate functionality

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -37,13 +37,15 @@
 
 
 class W_Dtype(W_Root):
-    _immutable_fields_ = ["itemtype?", "num", "kind", "name", "char",
-                          "w_box_type", "byteorder", "size?", "float_type",
-                          "fields?", "fieldnames?", "shape", "subdtype", "base"]
+    _immutable_fields_ = [
+        "num", "kind", "name", "char", "w_box_type", "float_type",
+        "itemtype?", "byteorder?", "fields?", "fieldnames?", "size?",
+        "shape?", "subdtype?", "base?"
+    ]
 
     def __init__(self, itemtype, num, kind, name, char, w_box_type, byteorder=NPY.NATIVE,
                  size=1, alternate_constructors=[], aliases=[], float_type=None,
-                 fields=None, fieldnames=None, shape=[], subdtype=None):
+                 fields={}, fieldnames=[], shape=[], subdtype=None):
         self.itemtype = itemtype
         self.num = num
         self.kind = kind
@@ -56,10 +58,8 @@
         self.aliases = aliases
         self.float_type = float_type
         self.fields = fields
-        if fieldnames is None:
-            fieldnames = []
         self.fieldnames = fieldnames
-        self.shape = list(shape)
+        self.shape = shape
         self.subdtype = subdtype
         if not subdtype:
             self.base = self
@@ -102,16 +102,16 @@
         return self.kind == NPY.GENBOOLLTR
 
     def is_record_type(self):
-        return self.fields is not None
+        return bool(self.fields)
 
     def is_str_type(self):
         return self.num == NPY.STRING
 
     def is_str_or_unicode(self):
-        return (self.num == NPY.STRING or self.num == NPY.UNICODE)
+        return self.num == NPY.STRING or self.num == NPY.UNICODE
 
     def is_flexible_type(self):
-        return (self.is_str_or_unicode() or self.is_record_type())
+        return self.is_str_or_unicode() or self.num == NPY.VOID
 
     def is_native(self):
         return self.byteorder in (NPY.NATIVE, NPY.NATBYTE)
@@ -125,28 +125,29 @@
         return get_dtype_cache(space).dtypes_by_name[self.byteorder + self.float_type]
 
     def descr_str(self, space):
-        if not self.num == NPY.VOID:
-            if self.char == 'S':
-                return space.wrap('|S' + str(self.get_size()))
-            else:
-                return self.descr_get_name(space)
+        if self.fields:
+            return space.str(self.descr_get_descr(space))
         elif self.subdtype is not None:
             return space.str(space.newtuple([
                 self.subdtype.descr_get_str(space),
                 self.descr_get_shape(space)]))
-        return space.str(self.descr_get_descr(space))
+        else:
+            if self.is_flexible_type():
+                return space.wrap('|' + self.char + str(self.get_size()))
+            else:
+                return self.descr_get_name(space)
 
     def descr_repr(self, space):
-        if not self.num == NPY.VOID:
-            if self.char == 'S':
-                r = space.wrap('S' + str(self.get_size()))
-            else:
-                r = self.descr_get_name(space)
+        if self.fields:
+            r = self.descr_get_descr(space)
         elif self.subdtype is not None:
             r = space.newtuple([self.subdtype.descr_get_str(space),
                                 self.descr_get_shape(space)])
         else:
-            r = self.descr_get_descr(space)
+            if self.is_flexible_type():
+                r = space.wrap(self.char + str(self.get_size()))
+            else:
+                r = self.descr_get_name(space)
         return space.wrap("dtype(%s)" % space.str_w(space.repr(r)))
 
     def descr_get_itemsize(self, space):
@@ -224,7 +225,7 @@
         return space.wrap(not self.eq(space, w_other))
 
     def descr_get_fields(self, space):
-        if self.fields is None:
+        if not self.fields:
             return space.w_None
         w_d = space.newdict()
         for name, (offset, subdtype) in self.fields.iteritems():
@@ -232,34 +233,13 @@
                           space.newtuple([subdtype, space.wrap(offset)]))
         return w_d
 
-    def descr_set_fields(self, space, w_fieldnames, w_fields):
-        if w_fields == space.w_None:
-            self.fields = None
-        else:
-            self.fieldnames = []
-            self.fields = {}
-            size = 0
-            for w_name in space.fixedview(w_fieldnames):
-                name = space.str_w(w_name)
-                value = space.getitem(w_fields, w_name)
-
-                dtype = space.getitem(value, space.wrap(0))
-                assert isinstance(dtype, W_Dtype)
-                offset = space.int_w(space.getitem(value, space.wrap(1)))
-
-                self.fieldnames.append(name)
-                self.fields[name] = offset, dtype
-                size += dtype.get_size()
-            self.itemtype = types.RecordType()
-            self.size = size
-
     def descr_get_names(self, space):
-        if len(self.fieldnames) == 0:
+        if not self.fields:
             return space.w_None
         return space.newtuple([space.wrap(name) for name in self.fieldnames])
 
     def descr_set_names(self, space, w_names):
-        if len(self.fieldnames) == 0:
+        if not self.fields:
             raise oefmt(space.w_ValueError, "there are no fields defined")
         if not space.issequence_w(w_names) or \
                 space.len_w(w_names) != len(self.fieldnames):
@@ -354,17 +334,63 @@
         return space.newtuple([w_class, builder_args, data])
 
     def descr_setstate(self, space, w_data):
-        if space.int_w(space.getitem(w_data, space.wrap(0))) != 3:
-            raise OperationError(space.w_NotImplementedError, space.wrap("Pickling protocol version not supported"))
+        if self.fields is None:  # if builtin dtype
+            return space.w_None
+
+        version = space.int_w(space.getitem(w_data, space.wrap(0)))
+        if version != 3:
+            raise oefmt(space.w_ValueError,
+                        "can't handle version %d of numpy.dtype pickle",
+                        version)
 
         endian = space.str_w(space.getitem(w_data, space.wrap(1)))
         if endian == NPY.NATBYTE:
             endian = NPY.NATIVE
-        self.byteorder = endian
 
+        w_subarray = space.getitem(w_data, space.wrap(2))
         w_fieldnames = space.getitem(w_data, space.wrap(3))
         w_fields = space.getitem(w_data, space.wrap(4))
-        self.descr_set_fields(space, w_fieldnames, w_fields)
+        size = space.int_w(space.getitem(w_data, space.wrap(5)))
+
+        if (w_fieldnames == space.w_None) != (w_fields == space.w_None):
+            raise oefmt(space.w_ValueError, "inconsistent fields and names")
+
+        self.byteorder = endian
+        self.shape = []
+        self.subdtype = None
+        self.base = self
+
+        if w_subarray != space.w_None:
+            if not space.isinstance_w(w_subarray, space.w_tuple) or \
+                    space.len_w(w_subarray) != 2:
+                raise oefmt(space.w_ValueError,
+                            "incorrect subarray in __setstate__")
+            subdtype, w_shape = space.fixedview(w_subarray)
+            assert isinstance(subdtype, W_Dtype)
+            if not base.issequence_w(space, w_shape):
+                self.shape = [space.int_w(w_shape)]
+            else:
+                self.shape = [space.int_w(w_s) for w_s in space.fixedview(w_shape)]
+            self.subdtype = subdtype
+            self.base = subdtype.base
+
+        if w_fieldnames != space.w_None:
+            self.fieldnames = []
+            self.fields = {}
+            for w_name in space.fixedview(w_fieldnames):
+                name = space.str_w(w_name)
+                value = space.getitem(w_fields, w_name)
+
+                dtype = space.getitem(value, space.wrap(0))
+                assert isinstance(dtype, W_Dtype)
+                offset = space.int_w(space.getitem(value, space.wrap(1)))
+
+                self.fieldnames.append(name)
+                self.fields[name] = offset, dtype
+            self.itemtype = types.RecordType()
+
+        if self.is_flexible_type():
+            self.size = size
 
     @unwrap_spec(new_order=str)
     def descr_newbyteorder(self, space, new_order=NPY.SWAP):
@@ -862,6 +888,8 @@
                     float_type=dtype.float_type)
             for alias in dtype.aliases:
                 self.dtypes_by_name[alias] = dtype
+        for dtype in self.dtypes_by_name.values():
+            dtype.fields = None  # mark these as builtin
 
         typeinfo_full = {
             'LONGLONG': self.w_int64dtype,
diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -1096,6 +1096,96 @@
         assert dt.subdtype == (dtype(float), (10,))
         assert dt.base == dtype(float)
 
+    def test_setstate(self):
+        import numpy as np
+        import sys
+        d = np.dtype('f8')
+        d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20, 1, 0))
+        assert d.str == ('<' if sys.byteorder == 'little' else '>') + 'f8'
+        assert d.fields is None
+        assert d.shape == ()
+        assert d.itemsize == 8
+        assert d.subdtype is None
+        assert repr(d) == "dtype('float64')"
+
+        d = np.dtype(('>' if sys.byteorder == 'little' else '<') + 'f8')
+        d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20, 1, 0))
+        assert d.str == '|f8'
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.itemsize == 8
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.itemsize == 16
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20, 1, 0))
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.itemsize == 20
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        d.__setstate__((3, '|', (np.dtype('float64'), 2), None, None, 20, 1, 0))
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.itemsize == 20
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        exc = raises(ValueError, "d.__setstate__((3, '|', None, ('f0', 'f1'), None, 16, 1, 0))")
+        assert exc.value[0] == 'inconsistent fields and names'
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        exc = raises(ValueError, "d.__setstate__((3, '|', None, None, {'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))")
+        assert exc.value[0] == 'inconsistent fields and names'
+        assert d.fields is None
+        assert d.shape == (2,)
+        assert d.subdtype is not None
+        assert repr(d) == "dtype(('<f8', (2,)))"
+
+        d = np.dtype(('<f8', 2))
+        exc = raises(ValueError, "d.__setstate__((3, '|', (np.dtype('float64'), (2,), 3), ('f0', 'f1'), {'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))")
+        assert exc.value[0] == 'incorrect subarray in __setstate__'
+        assert d.fields is None
+        assert d.shape == ()
+        assert d.subdtype is None
+        assert repr(d) == "dtype('V16')"
+
+        d = np.dtype(('<f8', 2))
+        d.__setstate__((3, '|', (np.dtype('float64'), (2,)), ('f0', 'f1'), {'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+        assert d.fields is not None
+        assert d.shape == (2,)
+        assert d.subdtype is not None
+        assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
+        d = np.dtype(('<f8', 2))
+        d.__setstate__((3, '|', None, ('f0', 'f1'), {'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+        assert d.fields is not None
+        assert d.shape == ()
+        assert d.subdtype is None
+        assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
+        d = np.dtype(('<f8', 2))
+        d.__setstate__((3, '|', None, ('f0', 'f1'), {'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+        d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 16, 1, 0))
+        assert d.fields is not None
+        assert d.shape == (2,)
+        assert d.subdtype is not None
+        assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
     def test_pickle_record(self):
         from numpypy import array, dtype
         from cPickle import loads, dumps
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to