Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: numpy-multidim
Changeset: r48510:61f36db28f06
Date: 2011-10-27 17:37 +0200
http://bitbucket.org/pypy/pypy/changeset/61f36db28f06/

Log:    setitem with slice - part one

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -10,9 +10,11 @@
 
 numpy_driver = jit.JitDriver(greens = ['signature'],
                              reds = ['result_size', 'i', 'self', 'result'])
-all_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self', 'dtype'])
-any_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self', 'dtype'])
-slice_driver = jit.JitDriver(greens=['signature'], reds=['i', 'j', 'step', 'stop', 'source', 'dest'])
+all_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
+                                                       'dtype'])
+any_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
+                                                       'dtype'])
+slice_driver = jit.JitDriver(greens=['signature'], reds=['i', 'self', 'source'])
 
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "signature"]
@@ -304,55 +306,26 @@
 
     def descr_setitem(self, space, w_idx, w_value):
         self.invalidated()
-        if self._single_item_at_index(space, w_idx):
+        if self._single_item_result(space, w_idx):
             item = self._single_item_at_index(space, w_idx)
             self.get_concrete().setitem_w(space, item, w_value)
             return
-        xxx
-        if space.isinstance_w(w_idx, space.w_tuple):
-            length = space.len_w(w_idx)
-            if length > 1: # only one dimension for now.
-                raise OperationError(space.w_IndexError,
-                                     space.wrap("invalid index"))
-            if length == 0:
-                w_idx = space.newslice(space.wrap(0),
-                                      space.wrap(self.find_size()),
-                                      space.wrap(1))
-            else:
-                w_idx = space.getitem(w_idx, space.wrap(0))
-        start, stop, step, slice_length = space.decode_index4(w_idx,
-                                                              self.find_size())
-        if step == 0:
-            # Single index
-            self.get_concrete().setitem_w(space, start, w_value)
+        concrete = self.get_concrete()
+        if isinstance(w_value, BaseArray):
+            # for now we just copy if setting part of an array from
+            # part of itself. can be improved.
+            if (concrete.get_root_storage() ==
+                w_value.get_concrete().get_root_storage()):
+                w_value = space.call_function(space.gettypefor(BaseArray), w_value)
+                assert isinstance(w_value, BaseArray)
         else:
-            concrete = self.get_concrete()
-            if isinstance(w_value, BaseArray):
-                # for now we just copy if setting part of an array from
-                # part of itself. can be improved.
-                if (concrete.get_root_storage() ==
-                    w_value.get_concrete().get_root_storage()):
-                    w_value = space.call_function(space.gettypefor(BaseArray), w_value)
-                    assert isinstance(w_value, BaseArray)
-            else:
-                w_value = convert_to_array(space, w_value)
-            concrete.setslice(space, start, stop, step,
-                                               slice_length, w_value)
+            w_value = convert_to_array(space, w_value)
+        view = self._create_slice(space, w_idx)
+        view.setslice(space, w_value)
 
     def descr_mean(self, space):
         return space.wrap(space.float_w(self.descr_sum(space))/self.find_size())
 
-    def _sliceloop(self, start, stop, step, source, dest):
-        i = start
-        j = 0
-        while (step > 0 and i < stop) or (step < 0 and i > stop):
-            slice_driver.jit_merge_point(signature=source.signature, step=step,
-                                         stop=stop, i=i, j=j, source=source,
-                                         dest=dest)
-            dest.setitem(i, source.eval(j).convert_to(dest.find_dtype()))
-            j += 1
-            i += step
-
 def convert_to_array(space, w_obj):
     if isinstance(w_obj, BaseArray):
         return w_obj
@@ -557,13 +530,23 @@
     def find_dtype(self):
         return self.parent.find_dtype()
 
-    def setslice(self, space, start, stop, step, slice_length, arr):
-        xxx
-        start = self.calc_index(start)
-        if stop != -1:
-            stop = self.calc_index(stop)
-        step = self.step * step
-        self._sliceloop(start, stop, step, arr, self.parent)
+    def setslice(self, space, w_value):
+        assert isinstance(w_value, NDimArray)
+        if self.shape != w_value.shape:
+            raise OperationError(space.w_TypeError, space.wrap(
+                "wrong assignment"))
+        self._sliceloop(w_value)
+
+    def _sliceloop(self, source):
+        i = 0
+        while i < self.size:
+            slice_driver.jit_merge_point(signature=source.signature, i=i,
+                                         self=self, source=source)
+            self.setitem(i, source.eval(i).convert_to(self.find_dtype()))
+            i += 1
+
+    def setitem(self, item, value):
+        self.parent.setitem(self.calc_index(item), value)
 
     def len_of_shape(self):
         return self.parent.len_of_shape() - self.shape_reduction
@@ -644,9 +627,6 @@
         self.invalidated()
         self.dtype.setitem(self.storage, item, value)
 
-    def setslice(self, space, start, stop, step, slice_length, arr):
-        self._sliceloop(start, stop, step, arr, self)
-
     def __del__(self):
         lltype.free(self.storage, flavor='raw', track_allocation=False)
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -658,6 +658,12 @@
         assert a[0][1][1] == 13
         assert a[1][2][1] == 15
 
+    def test_setitem_slice(self):
+        import numpy
+        a = numpy.zeros((3, 4))
+        a[1] = [1, 2, 3, 4]
+        assert a[1, 2] == 3
+
 class AppTestSupport(object):
     def setup_class(cls):
         import struct
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to