Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-multidim
Changeset: r49723:ce322f0bc995
Date: 2011-11-24 14:52 +0200
http://bitbucket.org/pypy/pypy/changeset/ce322f0bc995/

Log:    Progress on test_zjit. test_slice is now almost ready, apart from a
        few arraylen_gc operations of unknown origin. Somewhat worrying:
        this change expanded test_multidim_slice, which has a very large
        number of operations that got included in the bridge. Hopefully
        this will be fixed by JIT targets.

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -9,14 +9,16 @@
 from pypy.rlib.rstring import StringBuilder
 from pypy.rlib.objectmodel import instantiate
 
-numpy_driver = jit.JitDriver(greens=['signature'],
+numpy_driver = jit.JitDriver(greens=['shapelen', 'signature'],
                              reds=['result_size', 'i', 'ri', 'self',
                                      'result'])
-all_driver = jit.JitDriver(greens=['signature'], reds=['i', 'self', 'dtype'])
-any_driver = jit.JitDriver(greens=['signature'], reds=['i', 'self', 'dtype'])
-slice_driver = jit.JitDriver(greens=['signature'], reds=['self', 'source',
-                                                         'source_iter',
-                                                         'res_iter'])
+all_driver = jit.JitDriver(greens=['shapelen', 'signature'],
+                           reds=['i', 'self', 'dtype'])
+any_driver = jit.JitDriver(greens=['shapelen', 'signature'],
+                           reds=['i', 'self', 'dtype'])
+slice_driver = jit.JitDriver(greens=['shapelen', 'signature'],
+                             reds=['self', 'source', 'source_iter',
+                                   'res_iter'])
 
 def _find_shape_and_elems(space, w_iterable):
     shape = [space.len_w(w_iterable)]
@@ -119,15 +121,16 @@
         space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
     )
     arr = NDimArray(size, shape[:], dtype=dtype, order=order)
+    shapelen = len(shape)
     arr_iter = arr.start_iter(arr.shape)
     for i in range(len(elems_w)):
         w_elem = elems_w[i]
         dtype.setitem_w(space, arr.storage, arr_iter.offset, w_elem)
-        arr_iter = arr_iter.next()
+        arr_iter = arr_iter.next(shapelen)
     return arr
 
 class BaseIterator(object):
-    def next(self):
+    def next(self, shapelen):
         raise NotImplementedError
 
     def done(self):
@@ -141,7 +144,7 @@
         self.offset = 0
         self.size = size
 
-    def next(self):
+    def next(self, shapelen):
         arr = instantiate(ArrayIterator)
         arr.size = self.size
         arr.offset = self.offset + 1
@@ -159,22 +162,30 @@
         self.offset  = arr.start
         self.arr     = arr
         self._done   = False
-        self.shape_len = len(arr.shape)
 
     @jit.unroll_safe
-    def next(self):
-        shape_len = jit.promote(self.shape_len)
-        for i in range(shape_len - 1, -1, -1):
-            if self.indices[i] < self.arr.shape[i] - 1:
-                self.indices[i] += 1
-                self.offset += self.arr.strides[i]
+    def next(self, shapelen):
+        offset = self.offset
+        indices = [0] * shapelen
+        for i in range(shapelen):
+            indices[i] = self.indices[i]
+        done = False
+        for i in range(shapelen - 1, -1, -1):
+            if indices[i] < self.arr.shape[i] - 1:
+                indices[i] += 1
+                offset += self.arr.strides[i]
                 break
             else:
-                self.indices[i] = 0
-                self.offset -= self.arr.backstrides[i]
+                indices[i] = 0
+                offset -= self.arr.backstrides[i]
         else:
-            self._done = True
-        return self
+            done = True
+        res = instantiate(ViewIterator)
+        res.offset = offset
+        res.indices = indices
+        res.arr = self.arr
+        res._done = done
+        return res
 
     def done(self):
         return self._done
@@ -208,9 +219,8 @@
         self.arr = arr
 
     @jit.unroll_safe
-    def next(self):
-        shape_len = jit.promote(self.shape_len)
-        for i in range(shape_len - 1, -1, -1):
+    def next(self, shapelen):
+        for i in range(shapelen - 1, -1, -1):
             if self.indices[i] < self.res_shape[i] - 1:
                 self.indices[i] += 1
                 self.offset += self.strides[i]
@@ -233,8 +243,9 @@
         self.left = left
         self.right = right
 
-    def next(self):
-        return Call2Iterator(self.left.next(), self.right.next())
+    def next(self, shapelen):
+        return Call2Iterator(self.left.next(shapelen),
+                             self.right.next(shapelen))
 
     def done(self):
         if isinstance(self.left, ConstantIterator):
@@ -250,8 +261,8 @@
     def __init__(self, child):
         self.child = child
 
-    def next(self):
-        return Call1Iterator(self.child.next())
+    def next(self, shapelen):
+        return Call1Iterator(self.child.next(shapelen))
 
     def done(self):
         return self.child.done()
@@ -260,7 +271,7 @@
         return self.child.get_offset()
 
 class ConstantIterator(BaseIterator):
-    def next(self):
+    def next(self, shapelen):
         return self
 
     def done(self):
@@ -367,16 +378,18 @@
     descr_min = _reduce_ufunc_impl("minimum")
 
     def _reduce_argmax_argmin_impl(op_name):
-        reduce_driver = jit.JitDriver(greens=['signature'],
+        reduce_driver = jit.JitDriver(greens=['shapelen', 'signature'],
                          reds=['result', 'i', 'self', 'cur_best', 'dtype'])
         def loop(self):
-            i = self.start_iter(self.shape)
+            i = self.start_iter()
             result = i.get_offset()
             cur_best = self.eval(i)
-            i.next()
+            shapelen = len(self.shape)
+            i = i.next(shapelen)
             dtype = self.find_dtype()
             while not i.done():
                 reduce_driver.jit_merge_point(signature=self.signature,
+                                              shapelen=shapelen,
                                               self=self, dtype=dtype,
                                               i=i, result=result,
                                               cur_best=cur_best)
@@ -384,7 +397,7 @@
                 if dtype.ne(new_best, cur_best):
                     result = i.get_offset()
                     cur_best = new_best
-                i = i.next()
+                i = i.next(shapelen)
             return result
         def impl(self, space):
             size = self.find_size()
@@ -397,25 +410,30 @@
 
     def _all(self):
         dtype = self.find_dtype()
-        i = self.start_iter(self.shape)
+        i = self.start_iter()
+        shapelen = len(self.shape)
         while not i.done():
-            all_driver.jit_merge_point(signature=self.signature, self=self, 
dtype=dtype, i=i)
+            all_driver.jit_merge_point(signature=self.signature,
+                                       shapelen=shapelen, self=self,
+                                       dtype=dtype, i=i)
             if not dtype.bool(self.eval(i)):
                 return False
-            i = i.next()
+            i = i.next(shapelen)
         return True
     def descr_all(self, space):
         return space.wrap(self._all())
 
     def _any(self):
         dtype = self.find_dtype()
-        i = self.start_iter(self.shape)
+        i = self.start_iter()
+        shapelen = len(self.shape)
         while not i.done():
-            any_driver.jit_merge_point(signature=self.signature, self=self,
+            any_driver.jit_merge_point(signature=self.signature,
+                                       shapelen=shapelen, self=self,
                                        dtype=dtype, i=i)
             if dtype.bool(self.eval(i)):
                 return True
-            i = i.next()
+            i = i.next(shapelen)
         return False
     def descr_any(self, space):
         return space.wrap(self._any())
@@ -651,9 +669,6 @@
         view.setslice(space, w_value)
 
     def create_slice(self, space, chunks):
-        new_sig = signature.Signature.find_sig([
-            NDimSlice.signature, self.signature
-        ])
         if len(chunks) == 1:
             start, stop, step, lgt = chunks[0]
             if step == 0:
@@ -684,6 +699,9 @@
             shape += self.shape[s:]
             strides += self.strides[s:]
             backstrides += self.backstrides[s:]
+        new_sig = signature.Signature.find_sig([
+            NDimSlice.signature, self.signature,
+        ])
         return NDimSlice(self, new_sig, start, strides[:], backstrides[:],
                          shape[:])
 
@@ -787,15 +805,17 @@
         signature = self.signature
         result_size = self.find_size()
         result = NDimArray(result_size, self.shape, self.find_dtype())
+        shapelen = len(self.shape)
         i = self.start_iter()
         ri = result.start_iter()
         while not ri.done():
             numpy_driver.jit_merge_point(signature=signature,
+                                         shapelen=shapelen,
                                          result_size=result_size, i=i, ri=ri,
                                          self=self, result=result)
             result.dtype.setitem(result.storage, ri.offset, self.eval(i))
-            i = i.next()
-            ri = ri.next()
+            i = i.next(shapelen)
+            ri = ri.next(shapelen)
         return result
 
     def force_if_needed(self):
@@ -967,19 +987,23 @@
     def _sliceloop(self, source, res_shape):
         source_iter = source.start_iter(res_shape)
         res_iter = self.start_iter(res_shape)
+        shapelen = len(res_shape)
         while not res_iter.done():
             slice_driver.jit_merge_point(signature=source.signature,
+                                         shapelen=shapelen,
                                          self=self, source=source,
                                          res_iter=res_iter,
                                          source_iter=source_iter)
             self.setitem(res_iter.offset, source.eval(source_iter).convert_to(
                 self.find_dtype()))
-            source_iter = source_iter.next()
-            res_iter = res_iter.next()
+            source_iter = source_iter.next(shapelen)
+            res_iter = res_iter.next(shapelen)
 
     def start_iter(self, res_shape=None):
         if res_shape is not None and res_shape != self.shape:
             return BroadcastIterator(self, res_shape)
+        # XXX there is a possible optimization here with SingleDimViewIterator
+        #     ignore for now
         return ViewIterator(self)
 
     def setitem(self, item, value):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -9,7 +9,7 @@
 
 
 reduce_driver = jit.JitDriver(
-    greens = ["signature"],
+    greens = ['shapelen', "signature"],
     reds = ["i", "self", "dtype", "value", "obj"]
 )
 
@@ -63,26 +63,29 @@
             promote_to_largest=True
         )
         start = obj.start_iter(obj.shape)
+        shapelen = len(obj.shape)
         if self.identity is None:
             if size == 0:
                 raise operationerrfmt(space.w_ValueError, "zero-size array to "
                     "%s.reduce without identity", self.name)
             value = obj.eval(start).convert_to(dtype)
-            start = start.next()
+            start = start.next(shapelen)
         else:
             value = self.identity.convert_to(dtype)
         new_sig = signature.Signature.find_sig([
             self.reduce_signature, obj.signature
         ])
-        return self.reduce(new_sig, start, value, obj, dtype).wrap(space)
+        return self.reduce(new_sig, shapelen, start, value, obj,
+                           dtype).wrap(space)
 
-    def reduce(self, signature, i, value, obj, dtype):
+    def reduce(self, signature, shapelen, i, value, obj, dtype):
         while not i.done():
-            reduce_driver.jit_merge_point(signature=signature, self=self,
+            reduce_driver.jit_merge_point(signature=signature,
+                                          shapelen=shapelen, self=self,
                                           value=value, obj=obj, i=i,
                                           dtype=dtype)
             value = self.func(dtype, value, obj.eval(i).convert_to(dtype))
-            i = i.next()
+            i = i.next(shapelen)
         return value
 
 class W_Ufunc1(W_Ufunc):
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -49,4 +49,4 @@
     _immutable_fields_ = ["func"]
 
     def __init__(self, func):
-        self.func = func
\ No newline at end of file
+        self.func = func
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -655,6 +655,10 @@
         assert r == 4
         r = (a + a).argmax()
         assert r == 9
+        a = array([1, 0, 0])
+        assert a.argmax() == 0
+        a = array([0, 0, 1])
+        assert a.argmax() == 2
 
     def test_argmin(self):
         from numpypy import array
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -261,7 +261,7 @@
 
     def define_multidim_slice():
         return """
-        a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
+        a = [[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8], [7, 8, 9, 10], [9, 10, 
11, 12], [11, 12, 13, 14], [13, 14, 15, 16], [16, 17, 18, 19]]
         b = a -> ::2
         c = b + b
         c -> 1 -> 1
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to