Author: Brian Kearns <bdkea...@gmail.com> Branch: Changeset: r70761:667ad75d7ce9 Date: 2014-04-18 14:51 -0400 http://bitbucket.org/pypy/pypy/changeset/667ad75d7ce9/
Log: simplify nditer diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py --- a/pypy/module/micronumpy/iterators.py +++ b/pypy/module/micronumpy/iterators.py @@ -164,38 +164,6 @@ self.array.setitem(state.offset, elem) -class SliceIterator(ArrayIter): - def __init__(self, arr, strides, backstrides, shape, order="C", - backward=False, dtype=None): - if dtype is None: - dtype = arr.implementation.dtype - self.dtype = dtype - self.arr = arr - if backward: - self.slicesize = shape[0] - self.gap = [support.product(shape[1:]) * dtype.elsize] - strides = strides[1:] - backstrides = backstrides[1:] - shape = shape[1:] - strides.reverse() - backstrides.reverse() - shape.reverse() - size = support.product(shape) - else: - shape = [support.product(shape)] - strides, backstrides = calc_strides(shape, dtype, order) - size = 1 - self.slicesize = support.product(shape) - self.gap = strides - ArrayIter.__init__(self, arr.implementation, size, shape, strides, backstrides) - - def getslice(self): - from pypy.module.micronumpy.concrete import SliceArray - return SliceArray(self.offset, self.gap, self.backstrides, - [self.slicesize], self.arr.implementation, - self.arr, self.dtype) - - def AxisIter(array, shape, axis, cumulative): strides = array.get_strides() backstrides = array.get_backstrides() diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py --- a/pypy/module/micronumpy/nditer.py +++ b/pypy/module/micronumpy/nditer.py @@ -5,56 +5,33 @@ from pypy.module.micronumpy import ufuncs, support, concrete from pypy.module.micronumpy.base import W_NDimArray, convert_to_array from pypy.module.micronumpy.descriptor import decode_w_dtype -from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator +from pypy.module.micronumpy.iterators import ArrayIter from pypy.module.micronumpy.strides import (calculate_broadcast_strides, shape_agreement, shape_agreement_multiple) -class Iterator(object): - def __init__(self, nditer, index, it, op_flags): - self.nditer = nditer - self.index = index - self.it = it - self.st = it.reset() - self.op_flags = op_flags - - def done(self): - return self.it.done(self.st) - - def next(self): - self.st = self.it.next(self.st) - - def getitem(self, space, array): - return self.op_flags.get_it_item[self.index](space, self.nditer, self.it, self.st) - - def setitem(self, space, array, val): - xxx - - def parse_op_arg(space, name, w_op_flags, n, parse_one_arg): - ret = [] if space.is_w(w_op_flags, space.w_None): - for i in range(n): - ret.append(OpFlag()) - elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \ + w_op_flags = space.newtuple([space.wrap('readonly')]) + if not space.isinstance_w(w_op_flags, space.w_tuple) and not \ space.isinstance_w(w_op_flags, space.w_list): raise oefmt(space.w_ValueError, '%s must be a tuple or array of per-op flag-tuples', name) + ret = [] + w_lst = space.listview(w_op_flags) + if space.isinstance_w(w_lst[0], space.w_tuple) or \ + space.isinstance_w(w_lst[0], space.w_list): + if len(w_lst) != n: + raise oefmt(space.w_ValueError, + '%s must be a tuple or array of per-op flag-tuples', + name) + for item in w_lst: + ret.append(parse_one_arg(space, space.listview(item))) else: - w_lst = space.listview(w_op_flags) - if space.isinstance_w(w_lst[0], space.w_tuple) or \ - space.isinstance_w(w_lst[0], space.w_list): - if len(w_lst) != n: - raise oefmt(space.w_ValueError, - '%s must be a tuple or array of per-op flag-tuples', - name) - for item in w_lst: - ret.append(parse_one_arg(space, space.listview(item))) - else: - op_flag = parse_one_arg(space, w_lst) - for i in range(n): - ret.append(op_flag) + op_flag = parse_one_arg(space, w_lst) + for i in range(n): + ret.append(op_flag) return ret @@ -67,29 +44,6 @@ self.native_byte_order = False self.tmp_copy = '' self.allocate = False - self.get_it_item = (get_readonly_item, get_readonly_slice) - - -def get_readonly_item(space, nditer, it, st): - res = concrete.ConcreteNonWritableArrayWithBase( - [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer) - res.start = st.offset - return W_NDimArray(res) - - -def get_readwrite_item(space, nditer, it, st): - res = concrete.ConcreteArrayWithBase( - [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer) - res.start = st.offset - return W_NDimArray(res) - - -def get_readonly_slice(space, array, it): - return W_NDimArray(it.getslice().readonly()) - - -def get_readwrite_slice(space, array, it): - return W_NDimArray(it.getslice()) def parse_op_flag(space, lst): @@ -128,17 +82,10 @@ else: raise OperationError(space.w_ValueError, space.wrap( 'op_flags must be a tuple or array of per-op flag-tuples')) - if op_flag.rw == '': - raise oefmt(space.w_ValueError, - "None of the iterator flags READWRITE, READONLY, or " - "WRITEONLY were specified for an operand") - elif op_flag.rw == 'r': - op_flag.get_it_item = (get_readonly_item, get_readonly_slice) - elif op_flag.rw == 'rw': - op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice) - elif op_flag.rw == 'w': - # XXX Extra logic needed to make sure writeonly - op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice) + if op_flag.rw == '': + raise oefmt(space.w_ValueError, + "None of the iterator flags READWRITE, READONLY, or " + "WRITEONLY were specified for an operand") return op_flag @@ -230,12 +177,6 @@ return ArrayIter(imp, imp.get_size(), shape, r[0], r[1]) -def get_external_loop_iter(space, order, arr, shape): - imp = arr.implementation - backward = is_backward(imp, order) - return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward) - - class IndexIterator(object): def __init__(self, shape, backward=False): self.shape = shape @@ -326,8 +267,6 @@ out_dtype = None for i in range(len(self.seq)): if self.seq[i] is None: - self.op_flags[i].get_it_item = (get_readwrite_item, - get_readwrite_slice) self.op_flags[i].allocate = True continue if self.op_flags[i].rw == 'w': @@ -372,20 +311,9 @@ self.dtypes = [s.get_dtype() for s in self.seq] # create an iterator for each operand - if self.external_loop: - for i in range(len(self.seq)): - self.iters.append(Iterator( - self, 1, - get_external_loop_iter( - space, self.order, self.seq[i], iter_shape), - self.op_flags[i])) - else: - for i in range(len(self.seq)): - self.iters.append(Iterator( - self, 0, - get_iter( - space, self.order, self.seq[i], iter_shape, self.dtypes[i]), - self.op_flags[i])) + for i in range(len(self.seq)): + it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i]) + self.iters.append((it, it.reset())) def set_op_axes(self, space, w_op_axes): if space.len_w(w_op_axes) != len(self.seq): @@ -417,14 +345,24 @@ def descr_iter(self, space): return space.wrap(self) + def getitem(self, it, st, op_flags): + if op_flags.rw == 'r': + impl = concrete.ConcreteNonWritableArrayWithBase + else: + impl = concrete.ConcreteArrayWithBase + res = impl([], it.array.dtype, it.array.order, [], [], + it.array.storage, self) + res.start = st.offset + return W_NDimArray(res) + def descr_getitem(self, space, w_idx): idx = space.int_w(w_idx) try: - ret = space.wrap(self.iters[idx].getitem(space, self.seq[idx])) + it, st = self.iters[idx] except IndexError: raise oefmt(space.w_IndexError, "Iterator operand index %d is out of bounds", idx) - return ret + return self.getitem(it, st, self.op_flags[idx]) def descr_setitem(self, space, w_idx, w_value): raise oefmt(space.w_NotImplementedError, "not implemented yet") @@ -433,8 +371,8 @@ space.wrap(len(self.iters)) def descr_next(self, space): - for it in self.iters: - if not it.done(): + for it, st in self.iters: + if not it.done(st): break else: self.done = True @@ -445,9 +383,9 @@ self.index_iter.next() else: self.first_next = False - for i in range(len(self.iters)): - res.append(self.iters[i].getitem(space, self.seq[i])) - self.iters[i].next() + for i, (it, st) in enumerate(self.iters): + res.append(self.getitem(it, st, self.op_flags[i])) + self.iters[i] = (it, it.next(st)) if len(res) < 2: return res[0] return space.newtuple(res) @@ -455,10 +393,10 @@ def iternext(self): if self.index_iter: self.index_iter.next() - for i in range(len(self.iters)): - self.iters[i].next() - for it in self.iters: - if not it.done(): + for i, (it, st) in enumerate(self.iters): + self.iters[i] = (it, it.next(st)) + for it, st in self.iters: + if not it.done(st): break else: self.done = True _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit