Author: mattip <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r60307:92d3270bf0ab
Date: 2013-01-21 20:40 +0200
http://bitbucket.org/pypy/pypy/changeset/92d3270bf0ab/

Log:    test_zjit thinks cumulative is constant; hack around it by splitting the axis reduction into separate cumulative and non-cumulative loops

diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -173,32 +173,44 @@
         y_iter.next()
     return out
 
-axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
-                                    greens=['shapelen', 'cumultative',
+axis_reduce1__driver = jit.JitDriver(name='numpy_axis_reduce',  # JIT driver for the loop in do_cum_axis_reduce
+                                    greens=['shapelen',  # 'cumultative' is no longer a green; each variant has its own driver
                                             'func', 'dtype',
                                             'identity'],
                                     reds=['axis', 'arr', 'out', 'shape',
                                           'out_iter', 'arr_iter',
                                           'temp_iter'])
+axis_reduce2__driver = jit.JitDriver(name='numpy_axis_reduce',  # JIT driver for the loop in do_nocum_axis_reduce
+                                    greens=['shapelen',  # NOTE(review): same name string as axis_reduce1__driver -- confirm that is intended
+                                            'func', 'dtype',
+                                            'identity'],
+                                    reds=['axis', 'arr', 'out', 'shape',
+                                          'out_iter', 'arr_iter',  # no 'temp_iter': this loop reads its running value back from out
+                                          ])
 
 def do_axis_reduce(shape, func, arr, dtype, axis, out, identity, cumultative,
                    temp):
-    out_iter = out.create_axis_iter(arr.get_shape(), axis, cumultative)
     if cumultative:
-        temp_iter = temp.create_axis_iter(arr.get_shape(), axis, False)
+        return do_cum_axis_reduce(shape, func, arr, dtype, axis, out,  # cumulative: running values are also written to temp
+                identity, temp)  # split into two loops so neither JIT driver needs 'cumultative' as a green
     else:
-        temp_iter = out_iter # hack
+        return do_nocum_axis_reduce(shape, func, arr, dtype, axis, out,  # plain reduce along axis; temp is not used by this variant
+                identity, temp)
+
+def do_cum_axis_reduce(shape, func, arr, dtype, axis, out, identity,  # cumulative variant of do_axis_reduce
+                   temp):  # temp carries the running value: read before combining, written after each step
+    out_iter = out.create_axis_iter(arr.get_shape(), axis, True)  # cumulative axis iteration over out
+    temp_iter = temp.create_axis_iter(arr.get_shape(), axis, False)  # non-cumulative axis iteration over temp
     arr_iter = arr.create_iter()
     if identity is not None:
         identity = identity.convert_to(dtype)
     shapelen = len(shape)
     while not out_iter.done():
-        axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
+        axis_reduce1__driver.jit_merge_point(shapelen=shapelen, func=func,  # no 'cumultative' green: this loop is always cumulative
                                             dtype=dtype, identity=identity,
                                             axis=axis, arr=arr, out=out,
                                             shape=shape, out_iter=out_iter,
                                             arr_iter=arr_iter,
-                                            cumultative=cumultative,
                                             temp_iter=temp_iter)
         w_val = arr_iter.getitem().convert_to(dtype)
         if out_iter.first_line:
@@ -208,9 +220,34 @@
             cur = temp_iter.getitem()
             w_val = func(dtype, cur, w_val)
         out_iter.setitem(w_val)
-        if cumultative:
-            temp_iter.setitem(w_val)
-            temp_iter.next()
+        temp_iter.setitem(w_val)  # save running value; presumably read back via temp_iter.getitem() on a later iteration
+        temp_iter.next()
+        arr_iter.next()
+        out_iter.next()
+    return out
+
+def do_nocum_axis_reduce(shape, func, arr, dtype, axis, out, identity,  # non-cumulative variant of do_axis_reduce
+                   temp):  # temp is unused here; kept so both variants share a signature
+    out_iter = out.create_axis_iter(arr.get_shape(), axis, False)  # non-cumulative axis iteration over out
+    arr_iter = arr.create_iter()
+    if identity is not None:
+        identity = identity.convert_to(dtype)  # converted once, before the loop
+    shapelen = len(shape)
+    while not out_iter.done():
+        axis_reduce2__driver.jit_merge_point(shapelen=shapelen, func=func,  # no temp_iter red in this variant
+                                            dtype=dtype, identity=identity,
+                                            axis=axis, arr=arr, out=out,
+                                            shape=shape, out_iter=out_iter,
+                                            arr_iter=arr_iter,
+                                            )
+        w_val = arr_iter.getitem().convert_to(dtype)
+        if out_iter.first_line:  # presumably the first slice along axis: seed from identity if given -- confirm
+            if identity is not None:
+                w_val = func(dtype, identity, w_val)
+        else:
+            cur = out_iter.getitem()  # running value is read back from out itself
+            w_val = func(dtype, cur, w_val)
+        out_iter.setitem(w_val)
         arr_iter.next()
         out_iter.next()
     return out
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to