Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r70763:142661a32d5b Date: 2014-04-18 17:05 -0400 http://bitbucket.org/pypy/pypy/changeset/142661a32d5b/
Log: support more variations of ndarray.item() diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py --- a/pypy/module/micronumpy/ndarray.py +++ b/pypy/module/micronumpy/ndarray.py @@ -18,7 +18,7 @@ multi_axis_converter from pypy.module.micronumpy.flagsobj import W_FlagsObject from pypy.module.micronumpy.flatiter import W_FlatIterator -from pypy.module.micronumpy.strides import get_shape_from_iterable, to_coords, \ +from pypy.module.micronumpy.strides import get_shape_from_iterable, \ shape_agreement, shape_agreement_multiple @@ -469,29 +469,33 @@ def descr_get_flatiter(self, space): return space.wrap(W_FlatIterator(self)) - def to_coords(self, space, w_index): - coords, _, _ = to_coords(space, self.get_shape(), - self.get_size(), self.get_order(), - w_index) - return coords - - def descr_item(self, space, w_arg=None): - if space.is_none(w_arg): + def descr_item(self, space, __args__): + args_w, kw_w = __args__.unpack() + if len(args_w) == 1 and space.isinstance_w(args_w[0], space.w_tuple): + args_w = space.fixedview(args_w[0]) + shape = self.get_shape() + coords = [0] * len(shape) + if len(args_w) == 0: if self.get_size() == 1: w_obj = self.get_scalar_value() assert isinstance(w_obj, boxes.W_GenericBox) return w_obj.item(space) raise oefmt(space.w_ValueError, "can only convert an array of size 1 to a Python scalar") - if space.isinstance_w(w_arg, space.w_int): - if self.is_scalar(): - raise oefmt(space.w_IndexError, "index out of bounds") - i = self.to_coords(space, w_arg) - item = self.getitem(space, i) - assert isinstance(item, boxes.W_GenericBox) - return item.item(space) - raise OperationError(space.w_NotImplementedError, space.wrap( - "non-int arg not supported")) + elif len(args_w) == 1 and len(shape) != 1: + value = support.index_w(space, args_w[0]) + value = support.check_and_adjust_index(space, value, self.get_size(), -1) + for idim in range(len(shape) - 1, -1, -1): + coords[idim] = value % shape[idim] + value //= shape[idim] + elif len(args_w) == len(shape): + for idim in range(len(shape)): + coords[idim] = support.index_w(space, args_w[idim]) + else: + raise oefmt(space.w_ValueError, "incorrect number of indices for array") + item = self.getitem(space, coords) + assert isinstance(item, boxes.W_GenericBox) + return item.item(space) def descr_itemset(self, space, args_w): if len(args_w) == 0: diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py --- a/pypy/module/micronumpy/strides.py +++ b/pypy/module/micronumpy/strides.py @@ -233,30 +233,6 @@ return dtype -def to_coords(space, shape, size, order, w_item_or_slice): - '''Returns a start coord, step, and length. - ''' - start = lngth = step = 0 - if not (space.isinstance_w(w_item_or_slice, space.w_int) or - space.isinstance_w(w_item_or_slice, space.w_slice)): - raise OperationError(space.w_IndexError, - space.wrap('unsupported iterator index')) - - start, stop, step, lngth = space.decode_index4(w_item_or_slice, size) - - coords = [0] * len(shape) - i = start - if order == 'C': - for s in range(len(shape) -1, -1, -1): - coords[s] = i % shape[s] - i //= shape[s] - else: - for s in range(len(shape)): - coords[s] = i % shape[s] - i //= shape[s] - return coords, step, lngth - - @jit.unroll_safe def shape_agreement(space, shape1, w_arr2, broadcast_down=True): if w_arr2 is None: diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py --- a/pypy/module/micronumpy/support.py +++ b/pypy/module/micronumpy/support.py @@ -25,3 +25,18 @@ for x in s: i *= x return i + + +def check_and_adjust_index(space, index, size, axis): + if index < -size or index >= size: + if axis >= 0: + raise oefmt(space.w_IndexError, + "index %d is out of bounds for axis %d with size %d", + index, axis, size) + else: + raise oefmt(space.w_IndexError, + "index %d is out of bounds for size %d", + index, size) + if index < 0: + index += size + return index diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py --- a/pypy/module/micronumpy/test/test_ndarray.py +++ b/pypy/module/micronumpy/test/test_ndarray.py @@ -164,24 +164,6 @@ assert calc_new_strides([1, 1, 105, 1, 1], [7, 15], [1, 7],'F') == \ [1, 1, 1, 105, 105] - def test_to_coords(self): - from pypy.module.micronumpy.strides import to_coords - - def _to_coords(index, order): - return to_coords(self.space, [2, 3, 4], 24, order, - self.space.wrap(index))[0] - - assert _to_coords(0, 'C') == [0, 0, 0] - assert _to_coords(1, 'C') == [0, 0, 1] - assert _to_coords(-1, 'C') == [1, 2, 3] - assert _to_coords(5, 'C') == [0, 1, 1] - assert _to_coords(13, 'C') == [1, 0, 1] - assert _to_coords(0, 'F') == [0, 0, 0] - assert _to_coords(1, 'F') == [1, 0, 0] - assert _to_coords(-1, 'F') == [1, 2, 3] - assert _to_coords(5, 'F') == [1, 2, 0] - assert _to_coords(13, 'F') == [1, 0, 2] - def test_find_shape(self): from pypy.module.micronumpy.strides import find_shape_and_elems @@ -2988,12 +2970,14 @@ raises((IndexError, ValueError), "a.compress([1] * 100)") def test_item(self): + import numpy as np from numpypy import array assert array(3).item() == 3 assert type(array(3).item()) is int assert type(array(True).item()) is bool assert type(array(3.5).item()) is float - raises(IndexError, "array(3).item(15)") + exc = raises(IndexError, "array(3).item(15)") + assert str(exc.value) == 'index 15 is out of bounds for size 1' raises(ValueError, "array([1, 2, 3]).item()") assert array([3]).item(0) == 3 assert type(array([3]).item(0)) is int @@ -3012,6 +2996,11 @@ assert type(b[1]) is str assert b[0] == 1 assert b[1] == 'ab' + a = np.arange(24).reshape(2, 4, 3) + assert a.item(1, 1, 1) == 16 + assert a.item((1, 1, 1)) == 16 + exc = raises(ValueError, a.item, 1, 1, 1, 1) + assert str(exc.value) == "incorrect number of indices for array" def test_itemset(self): import numpy as np _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit