Author: Sergey Matyunin <[email protected]>
Branch: numpy_broadcast_nd
Changeset: r84066:d52b849b3779
Date: 2016-04-24 13:54 +0200
http://bitbucket.org/pypy/pypy/changeset/d52b849b3779/
Log: Implemented reset for numpy broadcast object.
diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
--- a/pypy/module/micronumpy/broadcast.py
+++ b/pypy/module/micronumpy/broadcast.py
@@ -75,6 +75,11 @@
return res[0]
return space.newtuple(res)
+ def descr_reset(self, space):
+ self.index = 0
+ self.done = False
+ for it in self.list_iter_state:
+ it.reset()
W_Broadcast.typedef = TypeDef("numpy.broadcast",
__new__=interp2app(descr_new_broadcast),
@@ -86,4 +91,5 @@
numiter=GetSetProperty(W_Broadcast.descr_get_numiter),
nd=GetSetProperty(W_Broadcast.descr_get_number_of_dimensions),
iters=GetSetProperty(W_Broadcast.descr_get_iters),
+ reset=interp2app(W_Broadcast.descr_reset),
)
diff --git a/pypy/module/micronumpy/flatiter.py b/pypy/module/micronumpy/flatiter.py
--- a/pypy/module/micronumpy/flatiter.py
+++ b/pypy/module/micronumpy/flatiter.py
@@ -76,7 +76,7 @@
base.get_order(), w_instance=base)
return loop.flatiter_getitem(res, self.iter, state, step)
finally:
- self.iter.reset(self.state, mutate=True)
+ self.reset()
def descr_setitem(self, space, w_idx, w_value):
if not (space.isinstance_w(w_idx, space.w_int) or
@@ -96,11 +96,14 @@
arr = convert_to_array(space, w_value)
loop.flatiter_setitem(space, dtype, arr, self.iter, state, step,
length)
finally:
- self.iter.reset(self.state, mutate=True)
+ self.reset()
def descr___array_wrap__(self, space, obj, w_context=None):
return obj
+ def reset(self):
+ self.iter.reset(self.state, mutate=True)
+
W_FlatIterator.typedef = TypeDef("numpy.flatiter",
base = GetSetProperty(W_FlatIterator.descr_base),
index = GetSetProperty(W_FlatIterator.descr_index),
diff --git a/pypy/module/micronumpy/test/test_broadcast.py b/pypy/module/micronumpy/test/test_broadcast.py
--- a/pypy/module/micronumpy/test/test_broadcast.py
+++ b/pypy/module/micronumpy/test/test_broadcast.py
@@ -123,3 +123,15 @@
assert step_in_y == y[0, 0] # == 3
assert step_in_broadcast == (1, 3)
assert step2_in_y == y[1, 0] # == 4
+
+ def test_broadcast_reset(self):
+ import numpy as np
+ x = np.array([1, 2, 3])
+ y = np.array([[4], [5], [6]])
+
+ b = np.broadcast(x, y)
+ b.next(), b.next(), b.next()
+ b.reset()
+
+ assert b.index == 0
+ assert b.next() == (1, 4)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit