Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-multidim-shards
Changeset: r49547:638b988b580e
Date: 2011-11-19 17:40 +0200
http://bitbucket.org/pypy/pypy/changeset/638b988b580e/
Log: in-progress. Get this into some shape so we can run tests
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -39,12 +39,19 @@
shape.append(size)
batch = new_batch
+class BroadcastDescription(object):
+ def __init__(self, shape, indices1, indices2):
+ self.shape = shape
+ self.indices1 = indices1
+ self.indices2 = indices2
+
def shape_agreement(space, shape1, shape2):
""" Checks agreement about two shapes with respect to broadcasting. Returns
the resulting shape.
"""
lshift = 0
rshift = 0
+ adjustment = False
if len(shape1) > len(shape2):
m = len(shape1)
n = len(shape2)
@@ -56,21 +63,35 @@
lshift = len(shape1) - len(shape2)
remainder = shape2
endshape = [0] * m
+ indices1 = [True] * m
+ indices2 = [True] * m
for i in range(m - 1, m - n - 1, -1):
left = shape1[i + lshift]
right = shape2[i + rshift]
if left == right:
endshape[i] = left
elif left == 1:
+ adjustment = True
endshape[i] = right
+ indices1[i + lshift] = False
elif right == 1:
+ adjustment = True
endshape[i] = left
+ indices2[i + rshift] = False
else:
raise OperationError(space.w_ValueError, space.wrap(
"frames are not aligned"))
for i in range(m - n):
+ adjustment = True
endshape[i] = remainder[i]
+ #if len(shape1) > len(shape2):
+ # xxx
+ #else:
+ # xxx
+ #if not adjustment:
+ # return None
return endshape
+ return BroadcastDescription(endshape, indices1, indices2)
def descr_new_array(space, w_subtype, w_item_or_iterable, w_dtype=None,
w_order=NoneNotWrapped):
@@ -105,7 +126,7 @@
space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
)
arr = NDimArray(size, shape[:], dtype=dtype, order=order)
- arr_iter = arr.start_iter()
+ arr_iter = arr.start_iter(arr.shape)
for i in range(len(elems_w)):
w_elem = elems_w[i]
dtype.setitem_w(space, arr.storage, arr_iter.offset, w_elem)
@@ -123,12 +144,13 @@
raise NotImplementedError
class ArrayIterator(BaseIterator):
- def __init__(self, size, offset=0):
- self.offset = offset
+ def __init__(self, size):
+ self.offset = 0
self.size = size
def next(self):
- return ArrayIterator(self.size, self.offset + 1)
+ self.offset += 1
+ return self
def done(self):
return self.offset >= self.size
@@ -137,34 +159,25 @@
return self.offset
class ViewIterator(BaseIterator):
- def __init__(self, arr, offset=0, indices=None, done=False):
- if indices is None:
- self.indices = [0] * len(arr.shape)
- self.offset = arr.start
- else:
- self.offset = offset
- self.indices = indices
- self.arr = arr
- self._done = done
+ def __init__(self, arr):
+ self.indices = [0] * len(arr.shape)
+ self.offset = arr.start
+ self.arr = arr
+ self._done = False
@jit.unroll_safe
def next(self):
- indices = [0] * len(self.arr.shape)
- for i in range(len(self.arr.shape)):
- indices[i] = self.indices[i]
- done = False
- offset = self.offset
for i in range(len(self.arr.shape) -1, -1, -1):
- if indices[i] < self.arr.shape[i] - 1:
- indices[i] += 1
- offset += self.arr.shards[i]
+ if self.indices[i] < self.arr.shape[i] - 1:
+ self.indices[i] += 1
+ self.offset += self.arr.shards[i]
break
else:
- indices[i] = 0
- offset -= self.arr.backshards[i]
+ self.indices[i] = 0
+ self.offset -= self.arr.backshards[i]
else:
- done = True
- return ViewIterator(self.arr, offset, indices, done)
+ self._done = True
+ return self
def done(self):
return self._done
@@ -172,13 +185,43 @@
def get_offset(self):
return self.offset
+class ResizingIterator(object):
+ def __init__(self, iter, shape, orig_indices):
+ self.shape = shape
+ self.indices = [0] * len(shape)
+ self.orig_indices = orig_indices
+ self.iter = iter
+ self._done = False
+
+ @jit.unroll_safe
+ def next(self):
+ for i in range(len(self.shape) -1, -1, -1):
+ if self.indices[i] < self.shape[i] - 1:
+ self.indices[i] += 1
+ if self.orig_indices[i]:
+ self.iter.next()
+ break
+ else:
+ self.indices[i] = 0
+ else:
+ self._done = True
+ return self
+
+ def get_offset(self):
+ return self.iter.get_offset()
+
+ def done(self):
+ return self._done
+
class Call2Iterator(BaseIterator):
def __init__(self, left, right):
self.left = left
self.right = right
def next(self):
- return Call2Iterator(self.left.next(), self.right.next())
+ self.left.next()
+ self.right.next()
+ return self
def done(self):
return self.left.done() or self.right.done()
@@ -193,7 +236,8 @@
self.child = child
def next(self):
- return Call1Iterator(self.child.next())
+ self.child.next()
+ return self
def done(self):
return self.child.done()
@@ -312,7 +356,7 @@
reduce_driver = jit.JitDriver(greens=['signature'],
reds = ['i', 'result', 'self', 'cur_best', 'dtype'])
def loop(self):
- i = self.start_iter()
+ i = self.start_iter(self.shape)
result = i.get_offset()
cur_best = self.eval(i)
i.next()
@@ -339,7 +383,7 @@
def _all(self):
dtype = self.find_dtype()
- i = self.start_iter()
+ i = self.start_iter(self.shape)
while not i.done():
all_driver.jit_merge_point(signature=self.signature, self=self,
dtype=dtype, i=i)
if not dtype.bool(self.eval(i)):
@@ -351,7 +395,7 @@
def _any(self):
dtype = self.find_dtype()
- i = self.start_iter()
+ i = self.start_iter(self.shape)
while not i.done():
any_driver.jit_merge_point(signature=self.signature, self=self,
dtype=dtype, i=i)
@@ -403,7 +447,7 @@
res.append_slice(str(self_shape), 1, len(self_shape) - 1)
res.append(')')
else:
- self.to_str(space, 1, res, indent=' ')
+ concrete.to_str(space, 1, res, indent=' ')
if (dtype is not space.fromcache(interp_dtype.W_Float64Dtype) and
dtype is not space.fromcache(interp_dtype.W_Int64Dtype)) or \
not self.find_size():
@@ -488,7 +532,8 @@
def descr_str(self, space):
ret = StringBuilder()
- self.to_str(space, 0, ret, ' ')
+ concrete = self.get_concrete()
+ concrete.to_str(space, 0, ret, ' ')
return space.wrap(ret.build())
def _index_of_single_item(self, space, w_idx):
@@ -633,12 +678,12 @@
except ValueError:
pass
return space.wrap(space.is_true(self.get_concrete().eval(
- self.start_iter()).wrap(space)))
+ self.start_iter(self.shape)).wrap(space)))
def getitem(self, item):
raise NotImplementedError
- def start_iter(self):
+ def start_iter(self, res_shape=None):
raise NotImplementedError
def compute_index(self, space, offset):
@@ -697,7 +742,7 @@
def eval(self, iter):
return self.value
- def start_iter(self):
+ def start_iter(self, res_shape=None):
return ConstantIterator()
def to_str(self, space, comma, builder, indent=' '):
@@ -787,10 +832,10 @@
assert isinstance(call_sig, signature.Call1)
return call_sig.func(self.res_dtype, val)
- def start_iter(self):
+ def start_iter(self, res_shape=None):
if self.forced_result is not None:
- return self.forced_result.start_iter()
- return Call1Iterator(self.values.start_iter())
+ return self.forced_result.start_iter(res_shape)
+ return Call1Iterator(self.values.start_iter(res_shape))
class Call2(VirtualArray):
"""
@@ -814,10 +859,13 @@
pass
return self.right.find_size()
- def start_iter(self):
+ def start_iter(self, res_shape=None):
if self.forced_result is not None:
- return self.forced_result.start_iter()
- return Call2Iterator(self.left.start_iter(), self.right.start_iter())
+ return self.forced_result.start_iter(res_shape)
+ if res_shape is None:
+ res_shape = self.shape # we still force the shape on children
+ return Call2Iterator(self.left.start_iter(res_shape),
+ self.right.start_iter(res_shape))
def _eval(self, iter):
assert isinstance(iter, Call2Iterator)
@@ -895,15 +943,12 @@
return self.parent.find_dtype()
def setslice(self, space, w_value):
- if isinstance(w_value, NDimArray):
- if self.shape != w_value.shape:
- raise OperationError(space.w_TypeError, space.wrap(
- "wrong assignment"))
- self._sliceloop(w_value)
+ res_shape = shape_agreement(space, self.shape, w_value.shape)
+ self._sliceloop(w_value, res_shape)
- def _sliceloop(self, source):
- source_iter = source.start_iter()
- res_iter = self.start_iter()
+ def _sliceloop(self, source, res_shape):
+ source_iter = source.start_iter(res_shape)
+ res_iter = self.start_iter(res_shape)
while not res_iter.done():
slice_driver.jit_merge_point(signature=source.signature,
self=self, source=source,
@@ -914,8 +959,11 @@
source_iter = source_iter.next()
res_iter = res_iter.next()
- def start_iter(self, offset=0, indices=None):
- return ViewIterator(self, offset=offset, indices=indices)
+ def start_iter(self, res_shape=None):
+ if res_shape is not None and res_shape != self.shape:
+ raise NotImplementedError # xxx
+ #return ResizingIterator(ViewIterator(self), res_shape,
orig_indices)
+ return ViewIterator(self)
def setitem(self, item, value):
self.parent.setitem(item, value)
@@ -967,9 +1015,11 @@
self.invalidated()
self.dtype.setitem(self.storage, item, value)
- def start_iter(self, offset=0, indices=None):
+ def start_iter(self, res_shape=None):
if self.order == 'C':
- return ArrayIterator(self.size, offset=offset)
+ if res_shape is not None and res_shape != self.shape:
+ raise NotImplementedError # xxx
+ return ArrayIterator(self.size)
raise NotImplementedError # use ViewIterator simply, test it
def __del__(self):
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -56,7 +56,7 @@
space, obj.find_dtype(),
promote_to_largest=True
)
- start = obj.start_iter()
+ start = obj.start_iter(obj.shape)
if self.identity is None:
if size == 0:
raise operationerrfmt(space.w_ValueError, "zero-size array to "
@@ -123,7 +123,7 @@
def call(self, space, args_w):
from pypy.module.micronumpy.interp_numarray import (Call2,
- convert_to_array, Scalar)
+ convert_to_array, Scalar, shape_agreement)
[w_lhs, w_rhs] = args_w
w_lhs = convert_to_array(space, w_lhs)
@@ -146,7 +146,8 @@
new_sig = signature.Signature.find_sig([
self.signature, w_lhs.signature, w_rhs.signature
])
- w_res = Call2(new_sig, w_lhs.shape or w_rhs.shape, calc_dtype,
+ new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape)
+ w_res = Call2(new_sig, new_shape, calc_dtype,
res_dtype, w_lhs, w_rhs)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -843,8 +843,17 @@
c = b + b
assert c[1][1] == 12
- def test_broadcast(self):
- skip("not working")
+ def test_broadcast_ufunc(self):
+ from numpy import array
+ a = array([[1, 2], [3, 4], [5, 6]])
+ b = array([5, 6])
+ #print a + b
+ c = ((a + b) == [[1+5, 2+6], [3+5, 4+6], [5+5, 6+6]])
+ print c
+ print c.all()
+ assert c.all()
+
+ def test_broadcast_setslice(self):
import numpy
a = numpy.zeros((100, 100))
b = numpy.ones(100)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit