Author: Brian Kearns <[email protected]>
Branch:
Changeset: r69346:12fa4e02e3cf
Date: 2014-02-24 02:15 -0500
http://bitbucket.org/pypy/pypy/changeset/12fa4e02e3cf/
Log: improve dtype setstate functionality
diff --git a/pypy/module/micronumpy/interp_dtype.py
b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -37,13 +37,15 @@
class W_Dtype(W_Root):
- _immutable_fields_ = ["itemtype?", "num", "kind", "name", "char",
- "w_box_type", "byteorder", "size?", "float_type",
- "fields?", "fieldnames?", "shape", "subdtype",
"base"]
+ _immutable_fields_ = [
+ "num", "kind", "name", "char", "w_box_type", "float_type",
+ "itemtype?", "byteorder?", "fields?", "fieldnames?", "size?",  # '?' entries are reassigned after __init__, e.g. by descr_setstate
+ "shape?", "subdtype?", "base?"
+ ]
def __init__(self, itemtype, num, kind, name, char, w_box_type,
byteorder=NPY.NATIVE,
size=1, alternate_constructors=[], aliases=[],
float_type=None,
- fields=None, fieldnames=None, shape=[], subdtype=None):
+ fields={}, fieldnames=[], shape=[], subdtype=None):  # NOTE(review): these {}/[] defaults are one object shared by every instance; rebind, never mutate in place
self.itemtype = itemtype
self.num = num
self.kind = kind
@@ -56,10 +58,8 @@
self.aliases = aliases
self.float_type = float_type
self.fields = fields
- if fieldnames is None:
- fieldnames = []
self.fieldnames = fieldnames
- self.shape = list(shape)
+ self.shape = shape  # no longer copied (was list(shape)): aliases the caller's list
self.subdtype = subdtype
if not subdtype:
self.base = self
@@ -102,16 +102,16 @@
return self.kind == NPY.GENBOOLLTR
def is_record_type(self):
- return self.fields is not None
+ return bool(self.fields)  # None (builtin) and {} (default) both mean "not a record"
def is_str_type(self):
return self.num == NPY.STRING
def is_str_or_unicode(self):
- return (self.num == NPY.STRING or self.num == NPY.UNICODE)
+ return self.num == NPY.STRING or self.num == NPY.UNICODE
def is_flexible_type(self):
- return (self.is_str_or_unicode() or self.is_record_type())
+ return self.is_str_or_unicode() or self.num == NPY.VOID  # VOID also covers unstructured 'V<n>' dtypes without fields
def is_native(self):
return self.byteorder in (NPY.NATIVE, NPY.NATBYTE)
@@ -125,28 +125,29 @@
return get_dtype_cache(space).dtypes_by_name[self.byteorder +
self.float_type]
def descr_str(self, space):
- if not self.num == NPY.VOID:
- if self.char == 'S':
- return space.wrap('|S' + str(self.get_size()))
- else:
- return self.descr_get_name(space)
+ if self.fields:  # record dtype: shown as its descr list
+ return space.str(self.descr_get_descr(space))
elif self.subdtype is not None:
return space.str(space.newtuple([
self.subdtype.descr_get_str(space),
self.descr_get_shape(space)]))
- return space.str(self.descr_get_descr(space))
+ else:
+ if self.is_flexible_type():  # string/unicode/void: '|' + type char + itemsize
+ return space.wrap('|' + self.char + str(self.get_size()))  # NOTE(review): numpy prints unicode as '<U<nchars>' - confirm for 'U'
+ else:
+ return self.descr_get_name(space)
def descr_repr(self, space):
- if not self.num == NPY.VOID:
- if self.char == 'S':
- r = space.wrap('S' + str(self.get_size()))
- else:
- r = self.descr_get_name(space)
+ if self.fields:  # same branch order as descr_str: fields, then subarray, then plain
+ r = self.descr_get_descr(space)
elif self.subdtype is not None:
r = space.newtuple([self.subdtype.descr_get_str(space),
self.descr_get_shape(space)])
else:
- r = self.descr_get_descr(space)
+ if self.is_flexible_type():  # e.g. 'V16': type char + itemsize
+ r = space.wrap(self.char + str(self.get_size()))
+ else:
+ r = self.descr_get_name(space)
return space.wrap("dtype(%s)" % space.str_w(space.repr(r)))
def descr_get_itemsize(self, space):
@@ -224,7 +225,7 @@
return space.wrap(not self.eq(space, w_other))
def descr_get_fields(self, space):
- if self.fields is None:
+ if not self.fields:  # None (builtin) or {} (non-record): no fields to report
return space.w_None
w_d = space.newdict()
for name, (offset, subdtype) in self.fields.iteritems():
@@ -232,34 +233,13 @@
space.newtuple([subdtype, space.wrap(offset)]))
return w_d
- def descr_set_fields(self, space, w_fieldnames, w_fields):
- if w_fields == space.w_None:
- self.fields = None
- else:
- self.fieldnames = []
- self.fields = {}
- size = 0
- for w_name in space.fixedview(w_fieldnames):
- name = space.str_w(w_name)
- value = space.getitem(w_fields, w_name)
-
- dtype = space.getitem(value, space.wrap(0))
- assert isinstance(dtype, W_Dtype)
- offset = space.int_w(space.getitem(value, space.wrap(1)))
-
- self.fieldnames.append(name)
- self.fields[name] = offset, dtype
- size += dtype.get_size()
- self.itemtype = types.RecordType()
- self.size = size
-
def descr_get_names(self, space):
- if len(self.fieldnames) == 0:
+ if not self.fields:  # names exist only for record dtypes
return space.w_None
return space.newtuple([space.wrap(name) for name in self.fieldnames])
def descr_set_names(self, space, w_names):
- if len(self.fieldnames) == 0:
+ if not self.fields:  # same record check as descr_get_names
raise oefmt(space.w_ValueError, "there are no fields defined")
if not space.issequence_w(w_names) or \
space.len_w(w_names) != len(self.fieldnames):
@@ -354,17 +334,79 @@
return space.newtuple([w_class, builder_args, data])
def descr_setstate(self, space, w_data):
- if space.int_w(space.getitem(w_data, space.wrap(0))) != 3:
- raise OperationError(space.w_NotImplementedError,
space.wrap("Pickling protocol version not supported"))
+ # Restore dtype state from w_data, read by index as
+ # (version, byteorder, subarray, names, fields, itemsize);
+ # items after index 5 are not read here.
+ if self.fields is None: # if builtin dtype
+ return space.w_None
+
+ version = space.int_w(space.getitem(w_data, space.wrap(0)))
+ if version != 3:
+ raise oefmt(space.w_ValueError,
+ "can't handle version %d of numpy.dtype pickle",
+ version)
endian = space.str_w(space.getitem(w_data, space.wrap(1)))
if endian == NPY.NATBYTE:
endian = NPY.NATIVE
- self.byteorder = endian
+ w_subarray = space.getitem(w_data, space.wrap(2))
w_fieldnames = space.getitem(w_data, space.wrap(3))
w_fields = space.getitem(w_data, space.wrap(4))
- self.descr_set_fields(space, w_fieldnames, w_fields)
+ size = space.int_w(space.getitem(w_data, space.wrap(5)))
+
+ if (w_fieldnames == space.w_None) != (w_fields == space.w_None):
+ raise oefmt(space.w_ValueError, "inconsistent fields and names")
+
+ # byteorder/shape/subdtype/base are overwritten before the subarray
+ # is validated, so a rejected subarray still leaves them reset.
+ self.byteorder = endian
+ self.shape = []
+ self.subdtype = None
+ self.base = self
+
+ if w_subarray != space.w_None:
+ if not space.isinstance_w(w_subarray, space.w_tuple) or \
+ space.len_w(w_subarray) != 2:
+ raise oefmt(space.w_ValueError,
+ "incorrect subarray in __setstate__")
+ subdtype, w_shape = space.fixedview(w_subarray)
+ # w_data comes from app-level code: a non-dtype here must be an
+ # app-level error, not an interpreter-level assertion failure.
+ if not isinstance(subdtype, W_Dtype):
+ raise oefmt(space.w_ValueError,
+ "incorrect subarray in __setstate__")
+ if not base.issequence_w(space, w_shape):
+ self.shape = [space.int_w(w_shape)]
+ else:
+ self.shape = [space.int_w(w_s) for w_s in
space.fixedview(w_shape)]
+ self.subdtype = subdtype
+ self.base = subdtype.base
+
+ if w_fieldnames != space.w_None:
+ # Build into locals and publish only after every entry checked
+ # out, so a malformed entry cannot leave a half-filled record.
+ fieldnames = []
+ fields = {}
+ for w_name in space.fixedview(w_fieldnames):
+ name = space.str_w(w_name)
+ value = space.getitem(w_fields, w_name)
+
+ dtype = space.getitem(value, space.wrap(0))
+ if not isinstance(dtype, W_Dtype):
+ raise oefmt(space.w_ValueError,
+ "incorrect field dtype in __setstate__")
+ offset = space.int_w(space.getitem(value, space.wrap(1)))
+
+ fieldnames.append(name)
+ fields[name] = offset, dtype
+ self.fieldnames = fieldnames
+ self.fields = fields
+ self.itemtype = types.RecordType()
+
+ # only flexible (string/unicode/void) dtypes take the pickled size
+ if self.is_flexible_type():
+ self.size = size
@unwrap_spec(new_order=str)
def descr_newbyteorder(self, space, new_order=NPY.SWAP):
@@ -862,6 +888,8 @@
float_type=dtype.float_type)
for alias in dtype.aliases:
self.dtypes_by_name[alias] = dtype
+ for dtype in self.dtypes_by_name.values():  # includes the aliases registered just above
+ dtype.fields = None # mark these as builtin; descr_setstate leaves builtins unchanged
typeinfo_full = {
'LONGLONG': self.w_int64dtype,
diff --git a/pypy/module/micronumpy/test/test_dtypes.py
b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -1096,6 +1096,96 @@
assert dt.subdtype == (dtype(float), (10,))
assert dt.base == dtype(float)
+ def test_setstate(self):
+ import numpy as np
+ import sys
+ d = np.dtype('f8')  # builtin dtype: __setstate__ must leave it untouched
+ d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20,
1, 0))
+ assert d.str == ('<' if sys.byteorder == 'little' else '>') + 'f8'
+ assert d.fields is None
+ assert d.shape == ()
+ assert d.itemsize == 8
+ assert d.subdtype is None
+ assert repr(d) == "dtype('float64')"
+
+ d = np.dtype(('>' if sys.byteorder == 'little' else '<') + 'f8')
+ d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20,
1, 0))
+ assert d.str == '|f8'
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.itemsize == 8
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.itemsize == 16
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 20,
1, 0))
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.itemsize == 20
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ d.__setstate__((3, '|', (np.dtype('float64'), 2), None, None, 20, 1,
0))
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.itemsize == 20
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ exc = raises(ValueError, "d.__setstate__((3, '|', None, ('f0', 'f1'),
None, 16, 1, 0))")
+ assert exc.value[0] == 'inconsistent fields and names'
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ exc = raises(ValueError, "d.__setstate__((3, '|', None, None, {'f0':
(np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))")
+ assert exc.value[0] == 'inconsistent fields and names'
+ assert d.fields is None
+ assert d.shape == (2,)
+ assert d.subdtype is not None
+ assert repr(d) == "dtype(('<f8', (2,)))"
+
+ d = np.dtype(('<f8', 2))
+ exc = raises(ValueError, "d.__setstate__((3, '|',
(np.dtype('float64'), (2,), 3), ('f0', 'f1'), {'f0': (np.dtype('float64'), 0),
'f1': (np.dtype('float64'), 8)}, 16, 1, 0))")
+ assert exc.value[0] == 'incorrect subarray in __setstate__'
+ assert d.fields is None
+ assert d.shape == ()
+ assert d.subdtype is None
+ assert repr(d) == "dtype('V16')"  # shape/subdtype were reset before the subarray was rejected
+
+ d = np.dtype(('<f8', 2))
+ d.__setstate__((3, '|', (np.dtype('float64'), (2,)), ('f0', 'f1'),
{'f0': (np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+ assert d.fields is not None
+ assert d.shape == (2,)
+ assert d.subdtype is not None
+ assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
+ d = np.dtype(('<f8', 2))
+ d.__setstate__((3, '|', None, ('f0', 'f1'), {'f0':
(np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+ assert d.fields is not None
+ assert d.shape == ()
+ assert d.subdtype is None
+ assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
+ d = np.dtype(('<f8', 2))
+ d.__setstate__((3, '|', None, ('f0', 'f1'), {'f0':
(np.dtype('float64'), 0), 'f1': (np.dtype('float64'), 8)}, 16, 1, 0))
+ d.__setstate__((3, '|', (np.dtype('float64'), (2,)), None, None, 16,
1, 0))
+ assert d.fields is not None  # names/fields of None keep the previously set fields
+ assert d.shape == (2,)
+ assert d.subdtype is not None
+ assert repr(d) == "dtype([('f0', '<f8'), ('f1', '<f8')])"
+
def test_pickle_record(self):
from numpypy import array, dtype
from cPickle import loads, dumps
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit