Author: Romain Guillebert <[email protected]>
Branch: numpypy-nditer
Changeset: r64823:000f7ae8ea0c
Date: 2013-06-07 16:43 +0200
http://bitbucket.org/pypy/pypy/changeset/000f7ae8ea0c/
Log: Merge heads
diff --git a/pypy/module/micronumpy/interp_nditer.py b/pypy/module/micronumpy/interp_nditer.py
--- a/pypy/module/micronumpy/interp_nditer.py
+++ b/pypy/module/micronumpy/interp_nditer.py
@@ -32,13 +32,13 @@
self.it.next()
def getitem(self, space, array):
- return self.op_flags.get_it_item(space, array, self.it)
+ return self.op_flags.get_it_item[self.index](space, array, self.it)
-class BoxIterator(IteratorMixin):
- pass
+class BoxIterator(IteratorMixin, AbstractIterator):
+ index = 0
-class SliceIterator(IteratorMixin):
- pass
+class ExternalLoopIterator(IteratorMixin, AbstractIterator):
+ index = 1
def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
ret = []
@@ -73,7 +73,7 @@
self.native_byte_order = False
self.tmp_copy = ''
self.allocate = False
- self.get_it_item = get_readonly_item
+ self.get_it_item = (get_readonly_item, get_readonly_slice)
def get_readonly_item(space, array, it):
return space.wrap(it.getitem())
@@ -128,9 +128,9 @@
raise OperationError(space.w_ValueError, space.wrap(
'op_flags must be a tuple or array of per-op flag-tuples'))
if op_flag.rw == 'r':
- op_flag.get_it_item = get_readonly_item
+ op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
elif op_flag.rw == 'rw':
- op_flag.get_it_item = get_readwrite_item
+ op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
return op_flag
def parse_func_flags(space, nditer, w_flags):
@@ -180,7 +180,8 @@
'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
'multi-index is being tracked'))
-def get_iter(space, order, imp, shape):
+def get_iter(space, order, arr, shape):
+ imp = arr.implementation
if order == 'K' or (order == 'C' and imp.order == 'C'):
backward = False
elif order =='F' and imp.order == 'C':
@@ -201,6 +202,18 @@
shape, backward)
return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
+def get_external_loop_iter(space, order, arr, shape):
+ imp = arr.implementation
+ if order == 'K' or (order == 'C' and imp.order == 'C'):
+ backward = False
+ elif order =='F' and imp.order == 'C':
+ backward = True
+ else:
+ raise OperationError(space.w_NotImplementedError, space.wrap(
+ 'not implemented yet'))
+
+ return SliceIterator(arr, imp.strides, imp.backstrides, shape,
order=order, backward=backward)
+
class W_NDIter(W_Root):
@@ -229,11 +242,13 @@
self.iters=[]
self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
if self.external_loop:
- #XXX find longest contiguous shape
- iter_shape = iter_shape[1:]
- for i in range(len(self.seq)):
- self.iters.append(BoxIterator(get_iter(space, self.order,
- self.seq[i].implementation, iter_shape),
self.op_flags[i]))
+ for i in range(len(self.seq)):
+
self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
+ self.seq[i], iter_shape), self.op_flags[i]))
+ else:
+ for i in range(len(self.seq)):
+ self.iters.append(BoxIterator(get_iter(space, self.order,
+ self.seq[i], iter_shape), self.op_flags[i]))
def descr_iter(self, space):
return space.wrap(self)
diff --git a/pypy/module/micronumpy/iter.py b/pypy/module/micronumpy/iter.py
--- a/pypy/module/micronumpy/iter.py
+++ b/pypy/module/micronumpy/iter.py
@@ -46,6 +46,7 @@
calculate_slice_strides
from pypy.module.micronumpy.base import W_NDimArray
from pypy.module.micronumpy.arrayimpl import base
+from pypy.module.micronumpy import support
from rpython.rlib import jit
# structures to describe slicing
@@ -267,28 +268,49 @@
self.offset %= self.size
class SliceIterator(object):
- def __init__(self, arr, stride, backstride, shape, dtype=None):
- self.step = 0
+ def __init__(self, arr, strides, backstrides, shape, order="C",
backward=False, dtype=None):
+ self.indexes = [0] * (len(shape) - 1)
+ self.offset = 0
self.arr = arr
- self.stride = stride
- self.backstride = backstride
- self.shape = shape
if dtype is None:
dtype = arr.implementation.dtype
+ if backward:
+ self.slicesize = shape[0]
+ self.gap = [support.product(shape[1:]) * dtype.get_size()]
+ self.strides = strides[1:][::-1]
+ self.backstrides = backstrides[1:][::-1]
+ self.shape = shape[1:][::-1]
+ self.shapelen = len(self.shape)
+ else:
+ shape = [support.product(shape)]
+ self.strides, self.backstrides = support.calc_strides(shape,
dtype, order)
+ self.slicesize = support.product(shape)
+ self.shapelen = 0
+ self.gap = self.strides
self.dtype = dtype
self._done = False
- def done():
+ def done(self):
return self._done
- def next():
- self.step += self.arr.implementation.dtype.get_size()
- if self.step == self.backstride - self.implementation.dtype.get_size():
+ @jit.unroll_safe
+ def next(self):
+ offset = self.offset
+ for i in range(self.shapelen - 1, -1, -1):
+ if self.indexes[i] < self.shape[i] - 1:
+ self.indexes[i] += 1
+ offset += self.strides[i]
+ break
+ else:
+ self.indexes[i] = 0
+ offset -= self.backstrides[i]
+ else:
self._done = True
+ self.offset = offset
def getslice(self):
from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
- return SliceArray(self.step, [self.stride], [self.backstride],
self.shape, self.arr.implementation, self.arr, self.dtype)
+ return SliceArray(self.offset, self.gap, self.backstrides,
[self.slicesize], self.arr.implementation, self.arr, self.dtype)
class AxisIterator(base.BaseArrayIterator):
def __init__(self, array, shape, dim, cumultative):
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -44,21 +44,21 @@
def test_external_loop(self):
from numpypy import arange, nditer, array
- a = arange(12).reshape(2,3,2)
+ a = arange(24).reshape(2, 3, 4)
r = []
n = 0
for x in nditer(a, flags=['external_loop']):
r.append(x)
n += 1
assert n == 1
- assert (array(r) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]).all()
+ assert (array(r) == range(24)).all()
r = []
n = 0
for x in nditer(a, flags=['external_loop'], order='F'):
r.append(x)
n += 1
- assert n == 6
- assert (array(r) == [[0, 6], [2, 8], [4, 10], [1, 7], [3, 9], [5,
11]]).all()
+ assert n == 12
+ assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17],
[ 9, 21], [ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23]]).all()
def test_interface(self):
from numpypy import arange, nditer, zeros
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit