Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: ufunc-reduce
Changeset: r78670:985e46cb47d8
Date: 2015-07-26 16:04 +0100
http://bitbucket.org/pypy/pypy/changeset/985e46cb47d8/

Log:    refactor .cumsum() and .cumprod()

diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -1145,14 +1145,14 @@
 
     # ----------------------- reduce -------------------------------
 
-    def _reduce_ufunc_impl(ufunc_name, name, variant=ufuncs.REDUCE, 
bool_result=False):
+    def _reduce_ufunc_impl(ufunc_name, name, bool_result=False):
         @unwrap_spec(keepdims=bool)
         def impl(self, space, w_axis=None, w_dtype=None, w_out=None, 
keepdims=False):
             out = out_converter(space, w_out)
             if bool_result:
                 w_dtype = descriptor.get_dtype_cache(space).w_booldtype
             return getattr(ufuncs.get(space), ufunc_name).reduce(
-                space, self, w_axis, keepdims, out, w_dtype, variant=variant)
+                space, self, w_axis, keepdims, out, w_dtype)
         impl.__name__ = name
         return impl
 
@@ -1163,8 +1163,23 @@
     descr_all = _reduce_ufunc_impl('logical_and', "descr_all", 
bool_result=True)
     descr_any = _reduce_ufunc_impl('logical_or', "descr_any", bool_result=True)
 
-    descr_cumsum = _reduce_ufunc_impl('add', "descr_cumsum", 
variant=ufuncs.ACCUMULATE)
-    descr_cumprod = _reduce_ufunc_impl('multiply', "descr_cumprod", 
variant=ufuncs.ACCUMULATE)
+
+    def _accumulate_method(ufunc_name, name):
+        def method(self, space, w_axis=None, w_dtype=None, w_out=None):
+            out = out_converter(space, w_out)
+            if space.is_none(w_axis):
+                w_axis = space.wrap(0)
+                arr = self.reshape(space, space.wrap(-1))
+            else:
+                arr = self
+            ufunc = getattr(ufuncs.get(space), ufunc_name)
+            return ufunc.reduce(space, arr, w_axis, False, out, w_dtype,
+                                variant=ufuncs.ACCUMULATE)
+        method.__name__ = name
+        return method
+
+    descr_cumsum = _accumulate_method('add', 'descr_cumsum')
+    descr_cumprod = _accumulate_method('multiply', 'descr_cumprod')
 
     def _reduce_argmax_argmin_impl(raw_name):
         op_name = "arg%s" % raw_name
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -159,7 +159,7 @@
         return retval
 
     def descr_accumulate(self, space, w_obj, w_axis=None, w_dtype=None, 
w_out=None):
-        if space.is_none(w_axis):
+        if w_axis is None:
             w_axis = space.wrap(0)
         out = out_converter(space, w_out)
         return self.reduce(space, w_obj, w_axis, True, #keepdims must be true
@@ -243,7 +243,9 @@
         if obj.is_scalar():
             return obj.get_scalar_value()
         shapelen = len(obj_shape)
+
         if space.is_none(w_axis):
+            axes = range(shapelen)
             axis = maxint
         else:
             if space.isinstance_w(w_axis, space.w_tuple) and 
space.len_w(w_axis) == 1:
@@ -253,6 +255,7 @@
                 raise oefmt(space.w_ValueError, "'axis' entry is out of 
bounds")
             if axis < 0:
                 axis += shapelen
+            axes = [axis]
         assert axis >= 0
         dtype = decode_w_dtype(space, dtype)
 
@@ -282,9 +285,14 @@
                             "which has no identity", self.name)
 
         if variant == ACCUMULATE:
+            if len(axes) != 1:
+                raise oefmt(space.w_ValueError,
+                    "accumulate does not allow multiple axes")
+            axis = axes[0]
+            assert axis >= 0
             dtype = self.find_binop_type(space, dtype)
             call__array_wrap__ = True
-            if shapelen > 1 and axis < shapelen:
+            if shapelen > 1:
                 temp = None
                 shape = obj_shape[:]
                 temp_shape = obj_shape[:axis] + obj_shape[axis + 1:]
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to