Author: Matti Picus <[email protected]>
Branch: numpypy-nditer
Changeset: r70694:7a493d5e93f3
Date: 2014-04-17 02:25 +0300
http://bitbucket.org/pypy/pypy/changeset/7a493d5e93f3/

Log:    wip

diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -45,6 +45,18 @@
 from pypy.module.micronumpy.strides import calc_strides
 from pypy.module.micronumpy.base import W_NDimArray
 
+class ScalarIter(object):
+    def __init__(self, array):
+        self.array = array
+
+    def done(self):
+        return True
+
+    def next(self):
+        pass
+
+    def getitem(self):
+        return self.array.getitem(0)
 
 class PureShapeIter(object):
     def __init__(self, shape, idx_w):
@@ -137,6 +149,7 @@
         return self.array.getitem_bool(self.offset)
 
     def setitem(self, elem):
+        print 'setting',self.offset,'to',elem
         self.array.setitem(self.offset, elem)
 
 class SliceIterator(ArrayIter):
diff --git a/pypy/module/micronumpy/nditer.py b/pypy/module/micronumpy/nditer.py
--- a/pypy/module/micronumpy/nditer.py
+++ b/pypy/module/micronumpy/nditer.py
@@ -5,7 +5,7 @@
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                              shape_agreement, shape_agreement_multiple)
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
+from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator, ScalarIter
 from pypy.module.micronumpy.concrete import SliceArray
 from pypy.module.micronumpy.descriptor import decode_w_dtype
 from pypy.module.micronumpy import ufuncs
@@ -205,6 +205,8 @@
 def get_iter(space, order, arr, shape, dtype):
     imp = arr.implementation.astype(space, dtype)
     backward = is_backward(imp, order)
+    if arr.is_scalar():
+        return ScalarIter(imp)
     if (imp.strides[0] < imp.strides[-1] and not backward) or \
        (imp.strides[0] > imp.strides[-1] and backward):
         # flip the strides. Is this always true for multidimension?
@@ -310,16 +312,19 @@
                                                            shape=out_shape)
         if len(outargs) > 0:
             # Make None operands writeonly and flagged for allocation
-            out_dtype = self.dtypes[0] if len(self.dtypes) > 0 else None
-            for i in range(len(self.seq)):
-                if self.seq[i] is None:
-                    self.op_flags[i].get_it_item = (get_readwrite_item,
+            if len(self.dtypes) > 0:
+                out_dtype = self.dtypes[outargs[0]]
+            else:
+                out_dtype = None
+                for i in range(len(self.seq)):
+                    if self.seq[i] is None:
+                        self.op_flags[i].get_it_item = (get_readwrite_item,
                                                     get_readwrite_slice)
-                    self.op_flags[i].allocate = True
-                    continue
-                if self.op_flags[i].rw == 'w':
-                    continue
-                out_dtype = ufuncs.find_binop_result_dtype(space,
+                        self.op_flags[i].allocate = True
+                        continue
+                    if self.op_flags[i].rw == 'w':
+                        continue
+                    out_dtype = ufuncs.find_binop_result_dtype(space,
                                                 self.seq[i].get_dtype(), out_dtype)
             for i in outargs:
                 if self.seq[i] is None:
@@ -346,7 +351,7 @@
                     self.dtypes[i] = seq_d
                 elif selfd != seq_d and not 'r' in self.op_flags[i].tmp_copy:
                     raise OperationError(space.w_TypeError, space.wrap(
-                        "Iterator operand required copying or buffering"))
+                        "Iterator operand required copying or buffering for operand %d" % i))
         else:
             #copy them from seq
             self.dtypes = [s.get_dtype() for s in self.seq]
diff --git a/pypy/module/micronumpy/test/test_nditer.py b/pypy/module/micronumpy/test/test_nditer.py
--- a/pypy/module/micronumpy/test/test_nditer.py
+++ b/pypy/module/micronumpy/test/test_nditer.py
@@ -155,6 +155,17 @@
             r.append(sqrt(x))
         assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
                             1+0j, 1.41421356237+0j]).sum()) < 1e-5
+        multi = nditer([None, array([2, 3], dtype='int64'), array(2., dtype='double')],
+                       op_dtypes = ['int64', 'int64', 'float64'],
+                       op_flags = [['writeonly', 'allocate'], ['readonly'], ['readonly']])
+        print 'starting the real mccoy'
+        for a, b, c in multi:
+            print 'in loop'
+            a[...] = b * c
+        print multi.operands[0]
+        print multi.operands[1]
+        print multi.operands[2]
+        assert (multi.operands[0] == [4, 6]).all()
 
     def test_casting(self):
         from numpy import arange, nditer
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to