Author: Ilya Osadchiy <[email protected]>
Branch: numpy-multidim-exp
Changeset: r45090:b2dc68ec3a1a
Date: 2011-06-21 22:28 +0300
http://bitbucket.org/pypy/pypy/changeset/b2dc68ec3a1a/
Log: numpy: something on multidimensions
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -83,12 +83,48 @@
     def descr_len(self, space):
         return self.get_concrete().descr_len(space)
 
+    def subscript_to_index(self, subscript, shape):
+        # Flatten a fully qualified subscript (expects unwrapped, already
+        # bounds-checked ints, one per dimension) into an offset into the
+        # flat storage; the first dimension varies fastest.
+        # TODO: is it better to store the cumulative product of shape and
+        # then compute index = reduce("add", map("mul", subscript, cummult_shape))?
+        index = 0
+        stride = 1
+        for ind, size in zip(subscript, shape):
+            index += ind * stride
+            stride *= size
+        return index
+
     def descr_getitem(self, space, w_idx):
-        # TODO: indexing by tuples
-        start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())
-        if step == 0:
-            # Single index
-            return space.wrap(self.get_concrete().getitem(start))
+        if space.is_true(space.isinstance(w_idx, space.w_tuple)):
+            subscript = space.unpacktuple(w_idx)
+            shape = self.find_shape()
+            if len(subscript) != len(shape):
+                # TODO: slices inside tuples, incomplete indices etc.
+                raise operationerrfmt(space.w_IndexError,
+                    "expected %d indices, got %d", len(shape), len(subscript))
+            # Fully qualified index: unwrap each component, normalize negative
+            # indices and bounds-check before touching the raw storage.
+            indices = []
+            for i in range(len(shape)):
+                ind = space.int_w(subscript[i])
+                if ind < 0:
+                    ind += shape[i]
+                if ind < 0 or ind >= shape[i]:
+                    raise operationerrfmt(space.w_IndexError,
+                        "index out of range for axis %d", i)
+                indices.append(ind)
+            idx = self.subscript_to_index(indices, shape)
+            is_single_elem = True
+        else:
+            start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())
+            idx = start
+            is_single_elem = (step == 0)
+
+        if is_single_elem:
+            # Single element
+            return space.wrap(self.get_concrete().getitem(idx))
         else:
             # Slice
             res = SingleDimSlice(start, stop, step, slice_length, self, self.signature.transition(SingleDimSlice.static_signature))
@@ -110,6 +129,10 @@
         BaseArray.__init__(self)
         self.float_value = float_value
 
+    def find_shape(self):
+        # A scalar has no shape; binary ops catch this and use the other operand.
+        raise ValueError
+
     def find_size(self):
         raise ValueError
 
@@ -120,6 +142,7 @@
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
+    _immutable_fields_ = ["shape"]
     def __init__(self, signature):
         BaseArray.__init__(self)
         self.forced_result = None
@@ -133,7 +156,11 @@
         i = 0
         signature = self.signature
         result_size = self.find_size()
-        result = SingleDimArray(result_size)
+        result_shape = self.find_shape()
+        if len(result_shape) == 1:
+            result = SingleDimArray(result_size)
+        else:
+            result = MultiDimArray(result_size, result_shape)
         while i < result_size:
             numpy_driver.jit_merge_point(signature=signature,
                                          result_size=result_size, i=i,
@@ -156,13 +183,19 @@
             return self.forced_result.eval(i)
         return self._eval(i)
 
+    def find_shape(self):
+        if self.forced_result is not None:
+            # The result has been computed and sources may be unavailable
+            return self.forced_result.find_shape()
+        return self._find_shape()
+
     def find_size(self):
         if self.forced_result is not None:
             # The result has been computed and sources may be unavailable
             return self.forced_result.find_size()
         return self._find_size()
 
 
 class Call1(VirtualArray):
     _immutable_fields_ = ["function", "values"]
 
@@ -174,6 +206,9 @@
     def _del_sources(self):
         self.values = None
 
+    def _find_shape(self):
+        return self.values.find_shape()
+
     def _find_size(self):
         return self.values.find_size()
 
@@ -195,6 +230,14 @@
         self.left = None
         self.right = None
 
+    def _find_shape(self):
+        # Take the shape from whichever operand has one; shapeless operands raise ValueError.
+        try:
+            return self.left.find_shape()
+        except ValueError:
+            pass
+        return self.right.find_shape()
+
     def _find_size(self):
         try:
             return self.left.find_size()
@@ -247,6 +289,9 @@
         self.step = step
         self.size = slice_length
 
+    def find_shape(self):
+        return (self.size,)
+
     def find_size(self):
         return self.size
 
@@ -254,7 +299,10 @@
         return (self.start + item * self.step)
 
 
-class SingleDimArray(BaseArray):
+class ConcreteArray(BaseArray):
+    """
+    Class for arrays that actually store data
+    """
     signature = Signature()
 
     def __init__(self, size):
@@ -273,6 +321,20 @@
     def eval(self, i):
         return self.storage[i]
 
+    def getitem(self, item):
+        return self.storage[item]
+
+    def __del__(self):
+        lltype.free(self.storage, flavor='raw')
+
+
+class SingleDimArray(ConcreteArray):
+    def __init__(self, size):
+        ConcreteArray.__init__(self, size)
+
+    def find_shape(self):
+        return (self.size,)
+
     def getindex(self, space, item):
         if item >= self.size:
             raise operationerrfmt(space.w_IndexError,
@@ -287,17 +348,50 @@
     def descr_len(self, space):
         return space.wrap(self.size)
 
-    def getitem(self, item):
-        return self.storage[item]
-
     @unwrap_spec(item=int, value=float)
     def descr_setitem(self, space, item, value):
         item = self.getindex(space, item)
         self.invalidated()
         self.storage[item] = value
 
-    def __del__(self):
-        lltype.free(self.storage, flavor='raw')
+
+class MultiDimArray(ConcreteArray):
+    _immutable_fields_ = ["shape"]
+
+    def __init__(self, size, shape):
+        ConcreteArray.__init__(self, size)
+        self.shape = shape
+
+    def find_shape(self):
+        return self.shape
+
+    def descr_len(self, space):
+        return space.wrap(self.shape[0])
+
+    def descr_setitem(self, space, w_subscript, w_value):
+        # Only fully qualified integer subscripts are supported so far.
+        subscript = space.listview(w_subscript)
+        shape = self.shape
+        if len(subscript) != len(shape):
+            raise operationerrfmt(space.w_IndexError,
+                "expected %d indices, got %d", len(shape), len(subscript))
+        # Same layout as subscript_to_index (first dimension varies fastest);
+        # every component is bounds-checked before the raw storage write.
+        item = 0
+        stride = 1
+        for i in range(len(shape)):
+            ind = space.int_w(subscript[i])
+            if ind < 0:
+                ind += shape[i]
+            if ind < 0 or ind >= shape[i]:
+                raise operationerrfmt(space.w_IndexError,
+                    "index out of range for axis %d", i)
+            item += ind * stride
+            stride *= shape[i]
+        value = space.float_w(w_value)
+        self.invalidated()
+        self.storage[item] = value
+
 
 def descr_new_numarray(space, w_type, w_size_or_iterable):
     l = space.listview(w_size_or_iterable)
@@ -308,10 +380,27 @@
         i += 1
     return space.wrap(arr)
 
-@unwrap_spec(ObjSpace, int)
-def zeros(space, size):
-    return space.wrap(SingleDimArray(size))
-
+def zeros(space, w_size):
+    # Accept either an int (1-d array) or a tuple of ints (the shape of a
+    # multi-dimensional array).
+    if space.is_true(space.isinstance(w_size, space.w_tuple)):
+        shape = [space.int_w(w_dim) for w_dim in space.listview(w_size)]
+        size = 1
+        for dim in shape:
+            if dim < 0:
+                raise operationerrfmt(space.w_ValueError,
+                    "negative dimensions are not allowed")
+            size *= dim
+        return space.wrap(MultiDimArray(size, tuple(shape)))
+    elif space.is_true(space.isinstance(w_size, space.w_int)):
+        size = space.int_w(w_size)
+        if size < 0:
+            raise operationerrfmt(space.w_ValueError,
+                "negative dimensions are not allowed")
+        return space.wrap(SingleDimArray(size))
+    else:
+        raise operationerrfmt(space.w_TypeError,
+            "expected sequence object with len >= 0")
 
 BaseArray.typedef = TypeDef(
     'numarray',
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -175,7 +175,6 @@
         a[2] = 20
         assert s[2] == 20
 
-
     def test_slice_invaidate(self):
         # check that slice shares invalidation list with
         from numpy import array
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit