Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: numpy-multidim-shards
Changeset: r49547:638b988b580e
Date: 2011-11-19 17:40 +0200
http://bitbucket.org/pypy/pypy/changeset/638b988b580e/
Log: in-progress. Get this into some shape so we can run tests diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -39,12 +39,19 @@ shape.append(size) batch = new_batch +class BroadcastDescription(object): + def __init__(self, shape, indices1, indices2): + self.shape = shape + self.indices1 = indices1 + self.indices2 = indices2 + def shape_agreement(space, shape1, shape2): """ Checks agreement about two shapes with respect to broadcasting. Returns the resulting shape. """ lshift = 0 rshift = 0 + adjustment = False if len(shape1) > len(shape2): m = len(shape1) n = len(shape2) @@ -56,21 +63,35 @@ lshift = len(shape1) - len(shape2) remainder = shape2 endshape = [0] * m + indices1 = [True] * m + indices2 = [True] * m for i in range(m - 1, m - n - 1, -1): left = shape1[i + lshift] right = shape2[i + rshift] if left == right: endshape[i] = left elif left == 1: + adjustment = True endshape[i] = right + indices1[i + lshift] = False elif right == 1: + adjustment = True endshape[i] = left + indices2[i + rshift] = False else: raise OperationError(space.w_ValueError, space.wrap( "frames are not aligned")) for i in range(m - n): + adjustment = True endshape[i] = remainder[i] + #if len(shape1) > len(shape2): + # xxx + #else: + # xxx + #if not adjustment: + # return None return endshape + return BroadcastDescription(endshape, indices1, indices2) def descr_new_array(space, w_subtype, w_item_or_iterable, w_dtype=None, w_order=NoneNotWrapped): @@ -105,7 +126,7 @@ space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype) ) arr = NDimArray(size, shape[:], dtype=dtype, order=order) - arr_iter = arr.start_iter() + arr_iter = arr.start_iter(arr.shape) for i in range(len(elems_w)): w_elem = elems_w[i] dtype.setitem_w(space, arr.storage, arr_iter.offset, w_elem) @@ -123,12 +144,13 @@ raise NotImplementedError class 
ArrayIterator(BaseIterator): - def __init__(self, size, offset=0): - self.offset = offset + def __init__(self, size): + self.offset = 0 self.size = size def next(self): - return ArrayIterator(self.size, self.offset + 1) + self.offset += 1 + return self def done(self): return self.offset >= self.size @@ -137,34 +159,25 @@ return self.offset class ViewIterator(BaseIterator): - def __init__(self, arr, offset=0, indices=None, done=False): - if indices is None: - self.indices = [0] * len(arr.shape) - self.offset = arr.start - else: - self.offset = offset - self.indices = indices - self.arr = arr - self._done = done + def __init__(self, arr): + self.indices = [0] * len(arr.shape) + self.offset = arr.start + self.arr = arr + self._done = False @jit.unroll_safe def next(self): - indices = [0] * len(self.arr.shape) - for i in range(len(self.arr.shape)): - indices[i] = self.indices[i] - done = False - offset = self.offset for i in range(len(self.arr.shape) -1, -1, -1): - if indices[i] < self.arr.shape[i] - 1: - indices[i] += 1 - offset += self.arr.shards[i] + if self.indices[i] < self.arr.shape[i] - 1: + self.indices[i] += 1 + self.offset += self.arr.shards[i] break else: - indices[i] = 0 - offset -= self.arr.backshards[i] + self.indices[i] = 0 + self.offset -= self.arr.backshards[i] else: - done = True - return ViewIterator(self.arr, offset, indices, done) + self._done = True + return self def done(self): return self._done @@ -172,13 +185,43 @@ def get_offset(self): return self.offset +class ResizingIterator(object): + def __init__(self, iter, shape, orig_indices): + self.shape = shape + self.indices = [0] * len(shape) + self.orig_indices = orig_indices + self.iter = iter + self._done = False + + @jit.unroll_safe + def next(self): + for i in range(len(self.shape) -1, -1, -1): + if self.indices[i] < self.shape[i] - 1: + self.indices[i] += 1 + if self.orig_indices[i]: + self.iter.next() + break + else: + self.indices[i] = 0 + else: + self._done = True + return self + + def 
get_offset(self): + return self.iter.get_offset() + + def done(self): + return self._done + class Call2Iterator(BaseIterator): def __init__(self, left, right): self.left = left self.right = right def next(self): - return Call2Iterator(self.left.next(), self.right.next()) + self.left.next() + self.right.next() + return self def done(self): return self.left.done() or self.right.done() @@ -193,7 +236,8 @@ self.child = child def next(self): - return Call1Iterator(self.child.next()) + self.child.next() + return self def done(self): return self.child.done() @@ -312,7 +356,7 @@ reduce_driver = jit.JitDriver(greens=['signature'], reds = ['i', 'result', 'self', 'cur_best', 'dtype']) def loop(self): - i = self.start_iter() + i = self.start_iter(self.shape) result = i.get_offset() cur_best = self.eval(i) i.next() @@ -339,7 +383,7 @@ def _all(self): dtype = self.find_dtype() - i = self.start_iter() + i = self.start_iter(self.shape) while not i.done(): all_driver.jit_merge_point(signature=self.signature, self=self, dtype=dtype, i=i) if not dtype.bool(self.eval(i)): @@ -351,7 +395,7 @@ def _any(self): dtype = self.find_dtype() - i = self.start_iter() + i = self.start_iter(self.shape) while not i.done(): any_driver.jit_merge_point(signature=self.signature, self=self, dtype=dtype, i=i) @@ -403,7 +447,7 @@ res.append_slice(str(self_shape), 1, len(self_shape) - 1) res.append(')') else: - self.to_str(space, 1, res, indent=' ') + concrete.to_str(space, 1, res, indent=' ') if (dtype is not space.fromcache(interp_dtype.W_Float64Dtype) and dtype is not space.fromcache(interp_dtype.W_Int64Dtype)) or \ not self.find_size(): @@ -488,7 +532,8 @@ def descr_str(self, space): ret = StringBuilder() - self.to_str(space, 0, ret, ' ') + concrete = self.get_concrete() + concrete.to_str(space, 0, ret, ' ') return space.wrap(ret.build()) def _index_of_single_item(self, space, w_idx): @@ -633,12 +678,12 @@ except ValueError: pass return space.wrap(space.is_true(self.get_concrete().eval( - 
self.start_iter()).wrap(space))) + self.start_iter(self.shape)).wrap(space))) def getitem(self, item): raise NotImplementedError - def start_iter(self): + def start_iter(self, res_shape=None): raise NotImplementedError def compute_index(self, space, offset): @@ -697,7 +742,7 @@ def eval(self, iter): return self.value - def start_iter(self): + def start_iter(self, res_shape=None): return ConstantIterator() def to_str(self, space, comma, builder, indent=' '): @@ -787,10 +832,10 @@ assert isinstance(call_sig, signature.Call1) return call_sig.func(self.res_dtype, val) - def start_iter(self): + def start_iter(self, res_shape=None): if self.forced_result is not None: - return self.forced_result.start_iter() - return Call1Iterator(self.values.start_iter()) + return self.forced_result.start_iter(res_shape) + return Call1Iterator(self.values.start_iter(res_shape)) class Call2(VirtualArray): """ @@ -814,10 +859,13 @@ pass return self.right.find_size() - def start_iter(self): + def start_iter(self, res_shape=None): if self.forced_result is not None: - return self.forced_result.start_iter() - return Call2Iterator(self.left.start_iter(), self.right.start_iter()) + return self.forced_result.start_iter(res_shape) + if res_shape is None: + res_shape = self.shape # we still force the shape on children + return Call2Iterator(self.left.start_iter(res_shape), + self.right.start_iter(res_shape)) def _eval(self, iter): assert isinstance(iter, Call2Iterator) @@ -895,15 +943,12 @@ return self.parent.find_dtype() def setslice(self, space, w_value): - if isinstance(w_value, NDimArray): - if self.shape != w_value.shape: - raise OperationError(space.w_TypeError, space.wrap( - "wrong assignment")) - self._sliceloop(w_value) + res_shape = shape_agreement(space, self.shape, w_value.shape) + self._sliceloop(w_value, res_shape) - def _sliceloop(self, source): - source_iter = source.start_iter() - res_iter = self.start_iter() + def _sliceloop(self, source, res_shape): + source_iter = 
source.start_iter(res_shape) + res_iter = self.start_iter(res_shape) while not res_iter.done(): slice_driver.jit_merge_point(signature=source.signature, self=self, source=source, @@ -914,8 +959,11 @@ source_iter = source_iter.next() res_iter = res_iter.next() - def start_iter(self, offset=0, indices=None): - return ViewIterator(self, offset=offset, indices=indices) + def start_iter(self, res_shape=None): + if res_shape is not None and res_shape != self.shape: + raise NotImplementedError # xxx + #return ResizingIterator(ViewIterator(self), res_shape, orig_indices) + return ViewIterator(self) def setitem(self, item, value): self.parent.setitem(item, value) @@ -967,9 +1015,11 @@ self.invalidated() self.dtype.setitem(self.storage, item, value) - def start_iter(self, offset=0, indices=None): + def start_iter(self, res_shape=None): if self.order == 'C': - return ArrayIterator(self.size, offset=offset) + if res_shape is not None and res_shape != self.shape: + raise NotImplementedError # xxx + return ArrayIterator(self.size) raise NotImplementedError # use ViewIterator simply, test it def __del__(self): diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -56,7 +56,7 @@ space, obj.find_dtype(), promote_to_largest=True ) - start = obj.start_iter() + start = obj.start_iter(obj.shape) if self.identity is None: if size == 0: raise operationerrfmt(space.w_ValueError, "zero-size array to " @@ -123,7 +123,7 @@ def call(self, space, args_w): from pypy.module.micronumpy.interp_numarray import (Call2, - convert_to_array, Scalar) + convert_to_array, Scalar, shape_agreement) [w_lhs, w_rhs] = args_w w_lhs = convert_to_array(space, w_lhs) @@ -146,7 +146,8 @@ new_sig = signature.Signature.find_sig([ self.signature, w_lhs.signature, w_rhs.signature ]) - w_res = Call2(new_sig, w_lhs.shape or w_rhs.shape, calc_dtype, + new_shape = shape_agreement(space, 
w_lhs.shape, w_rhs.shape) + w_res = Call2(new_sig, new_shape, calc_dtype, res_dtype, w_lhs, w_rhs) w_lhs.add_invalidates(w_res) w_rhs.add_invalidates(w_res) diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -843,8 +843,17 @@ c = b + b assert c[1][1] == 12 - def test_broadcast(self): - skip("not working") + def test_broadcast_ufunc(self): + from numpy import array + a = array([[1, 2], [3, 4], [5, 6]]) + b = array([5, 6]) + #print a + b + c = ((a + b) == [[1+5, 2+6], [3+5, 4+6], [5+5, 6+6]]) + print c + print c.all() + assert c.all() + + def test_broadcast_setslice(self): import numpy a = numpy.zeros((100, 100)) b = numpy.ones(100) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit