Author: Maciej Fijalkowski <fij...@gmail.com> Branch: numpy-refactor Changeset: r57185:7ce8d4dbc076 Date: 2012-09-06 19:36 +0200 http://bitbucket.org/pypy/pypy/changeset/7ce8d4dbc076/
Log: some fixes for __getitem__ diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py --- a/pypy/module/micronumpy/arrayimpl/concrete.py +++ b/pypy/module/micronumpy/arrayimpl/concrete.py @@ -31,7 +31,6 @@ def next_skip_x(self, x): self.offset += self.skip * x - self.index += x def done(self): return self.offset >= self.size diff --git a/pypy/module/micronumpy/interp_flatiter.py b/pypy/module/micronumpy/interp_flatiter.py --- a/pypy/module/micronumpy/interp_flatiter.py +++ b/pypy/module/micronumpy/interp_flatiter.py @@ -1,15 +1,23 @@ +from pypy.module.micronumpy.base import W_NDimArray +from pypy.module.micronumpy import loop +from pypy.module.micronumpy.strides import to_coords from pypy.interpreter.baseobjspace import Wrappable from pypy.interpreter.error import OperationError -from pypy.interpreter.typedef import TypeDef, interp2app -from pypy.rlib import jit +from pypy.interpreter.typedef import TypeDef, interp2app, GetSetProperty class W_FlatIterator(Wrappable): def __init__(self, arr): self.base = arr - self.iter = arr.create_iter() + self.reset() + + def reset(self): + self.iter = self.base.create_iter() self.index = 0 + def descr_len(self, space): + return space.wrap(self.base.get_size()) + def descr_next(self, space): if self.iter.done(): raise OperationError(space.w_StopIteration, space.w_None) @@ -18,44 +26,47 @@ self.index += 1 return w_res - @jit.unroll_safe + def descr_index(self, space): + return space.wrap(self.index) + + def descr_coords(self, space): + coords, step, lngth = to_coords(space, self.base.get_shape(), + self.base.get_size(), self.base.get_order(), + space.wrap(self.index)) + return space.newtuple([space.wrap(c) for c in coords]) + def descr_getitem(self, space, w_idx): if not (space.isinstance_w(w_idx, space.w_int) or space.isinstance_w(w_idx, space.w_slice)): raise OperationError(space.w_IndexError, space.wrap('unsupported iterator index')) + self.reset() base = self.base start, stop, step, length = space.decode_index4(w_idx, base.get_size()) # setslice would have been better, but flat[u:v] for arbitrary # shapes of array a cannot be represented as a[x1:x2, y1:y2] base_iter = base.create_iter() - xxx - return base.getitem(basei.offset) - base_iter = ViewIterator(base.start, base.strides, - base.backstrides, base.shape) - shapelen = len(base.shape) - basei = basei.next_skip_x(shapelen, start) - res = W_NDimArray([lngth], base.dtype, base.order) - ri = res.create_iter() - while not ri.done(): - flat_get_driver.jit_merge_point(shapelen=shapelen, - base=base, - basei=basei, - step=step, - res=res, - ri=ri) - w_val = base.getitem(basei.offset) - res.setitem(ri.offset, w_val) - basei = basei.next_skip_x(shapelen, step) - ri = ri.next(shapelen) - return res + base_iter.next_skip_x(start) + if length == 1: + return base_iter.getitem() + res = W_NDimArray.from_shape([length], base.get_dtype(), + base.get_order()) + return loop.flatiter_getitem(res, base_iter, step) def descr_iter(self): return self + def descr_base(self, space): + return space.wrap(self.base) + W_FlatIterator.typedef = TypeDef( 'flatiter', - __iter__ = interp2app(W_FlatIterator.descr_iter), + __iter__ = interp2app(W_FlatIterator.descr_iter), + __getitem__ = interp2app(W_FlatIterator.descr_getitem), + __len__ = interp2app(W_FlatIterator.descr_len), next = interp2app(W_FlatIterator.descr_next), + base = GetSetProperty(W_FlatIterator.descr_base), + index = GetSetProperty(W_FlatIterator.descr_index), + coords = GetSetProperty(W_FlatIterator.descr_coords), ) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -40,6 +40,9 @@ def get_dtype(self): return self.implementation.dtype + def get_order(self): + return self.implementation.order + def descr_get_dtype(self, space): return self.implementation.dtype diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -186,3 +186,11 @@ arr_iter.next() index_iter.next() value_iter.next() + +def flatiter_getitem(res, base_iter, step): + ri = res.create_iter() + while not ri.done(): + ri.setitem(base_iter.getitem()) + base_iter.next_skip_x(step) + ri.next() + return res diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -1814,6 +1814,8 @@ b.next() b.next() b.next() + assert b.index == 3 + assert b.coords == (0, 3) assert b[3] == 3 assert (b[::3] == [0, 3, 6, 9]).all() assert (b[2::5] == [2, 7]).all() @@ -1822,7 +1824,7 @@ raises(IndexError, "b[-11]") raises(IndexError, 'b[0, 1]') assert b.index == 0 - assert b.coords == (0,0) + assert b.coords == (0, 0) def test_flatiter_setitem(self): from _numpypy import arange, array _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit