Author: Maciej Fijalkowski <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r58584:47b57e79e2fa
Date: 2012-10-29 14:59 +0100
http://bitbucket.org/pypy/pypy/changeset/47b57e79e2fa/
Log: implement ndarray.choose
diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -13,12 +13,6 @@
from pypy.module.micronumpy.arrayimpl.sort import argsort_array
from pypy.rlib.debug import make_sure_not_resized
-def int_w(space, w_obj):
- try:
- return space.int_w(space.index(w_obj))
- except OperationError:
- return space.int_w(space.int(w_obj))
-
class BaseConcreteArray(base.BaseArrayImplementation):
start = 0
parent = None
@@ -85,7 +79,7 @@
for i, w_index in enumerate(view_w):
if space.isinstance_w(w_index, space.w_slice):
raise IndexError
- idx = int_w(space, w_index)
+ idx = support.int_w(space, w_index)
if idx < 0:
idx = self.get_shape()[i] + idx
if idx < 0 or idx >= self.get_shape()[i]:
@@ -159,7 +153,7 @@
return self._lookup_by_index(space, view_w)
if shape_len > 1:
raise IndexError
- idx = int_w(space, w_idx)
+ idx = support.int_w(space, w_idx)
return self._lookup_by_index(space, [space.wrap(idx)])
@jit.unroll_safe
diff --git a/pypy/module/micronumpy/constants.py b/pypy/module/micronumpy/constants.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/constants.py
@@ -0,0 +1,12 @@
+
+# Out-of-range index handling modes understood by the choose loop.
+MODE_WRAP = 0
+MODE_RAISE = 1
+MODE_CLIP = 2
+
+# Maps the user-supplied ``mode`` string to one of the constants above.
+MODES = {
+    'wrap': MODE_WRAP,
+    'raise': MODE_RAISE,
+    'clip': MODE_CLIP,
+}
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -3,6 +3,7 @@
from pypy.module.micronumpy import loop, interp_ufuncs
from pypy.module.micronumpy.iter import Chunk, Chunks
from pypy.module.micronumpy.strides import shape_agreement
+from pypy.module.micronumpy.constants import MODES
from pypy.interpreter.error import OperationError, operationerrfmt
from pypy.interpreter.gateway import unwrap_spec
@@ -153,3 +154,28 @@
def count_nonzero(space, w_obj):
return space.wrap(loop.count_all_true(convert_to_array(space, w_obj)))
+
+def choose(space, arr, w_choices, out, mode):
+    """Build an array whose elements are picked from ``w_choices``, using
+    the integer entries of ``arr`` as selectors.
+
+    ``out``, if not None, receives the result and is returned; otherwise a
+    new array of the broadcast shape and the promoted dtype of the choices
+    is allocated.  ``mode`` must be a key of ``MODES``; it decides how
+    out-of-range selectors are handled by ``loop.choose``.
+
+    Raises ValueError for an unknown mode or an empty choices list.
+    """
+    # Reject a bad mode before converting inputs, broadcasting, or
+    # allocating a (possibly large) output array that would be thrown away.
+    if mode not in MODES:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("mode %s not known" % (mode,)))
+    choices = [convert_to_array(space, w_item) for w_item
+               in space.listview(w_choices)]
+    if not choices:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("choices list cannot be empty"))
+    # The selector array, every choice and (if given) out are combined into
+    # one iteration shape.
+    # NOTE(review): shape_agreement appears to let ``out`` broadcast against
+    # the inputs rather than requiring an exact shape match as numpy does;
+    # confirm.
+    shape = arr.get_shape()
+    for choice in choices:
+        shape = shape_agreement(space, shape, choice)
+    if out is not None:
+        shape = shape_agreement(space, shape, out)
+    # The result dtype is the binop promotion of the choices' dtypes only;
+    # arr merely supplies the selector indices.
+    dtype = choices[0].get_dtype()
+    for choice in choices[1:]:
+        dtype = interp_ufuncs.find_binop_result_dtype(space, dtype,
+                                                      choice.get_dtype())
+    if out is None:
+        out = W_NDimArray.from_shape(shape, dtype)
+    loop.choose(space, arr, choices, shape, dtype, out, MODES[mode])
+    return out
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -4,7 +4,8 @@
from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\
ArrayArgumentException
-from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
+from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes,\
+ interp_arrayops
from pypy.module.micronumpy.strides import find_shape_and_elems,\
get_shape_from_iterable, to_coords, shape_agreement
from pypy.module.micronumpy.interp_flatiter import W_FlatIterator
@@ -402,9 +403,11 @@
return res
@unwrap_spec(mode=str)
- def descr_choose(self, space, w_choices, w_out=None, mode='raise'):
- raise OperationError(space.w_NotImplementedError, space.wrap(
- "choose not implemented yet"))
+    def descr_choose(self, space, w_choices, mode='raise', w_out=None):
+        """Implement ndarray.choose for this array.
+
+        A ``w_out`` that is neither None nor a W_NDimArray is rejected with
+        TypeError; otherwise the work is delegated to interp_arrayops.choose.
+
+        NOTE(review): the removed stub took ``w_out`` before ``mode``
+        (numpy's choose(choices, out=None, mode='raise') order), while this
+        version puts ``mode`` second, so a positional ``out`` would be parsed
+        as the mode string. Confirm this ordering is intended.
+        """
+        if w_out is None or isinstance(w_out, W_NDimArray):
+            return interp_arrayops.choose(space, self, w_choices, w_out, mode)
+        raise OperationError(space.w_TypeError, space.wrap(
+            "return arrays must be of ArrayType"))
def descr_clip(self, space, w_min, w_max, w_out=None):
raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,11 +4,14 @@
over all the array elements.
"""
+from pypy.interpreter.error import OperationError
from pypy.rlib.rstring import StringBuilder
from pypy.rlib import jit
from pypy.rpython.lltypesystem import lltype, rffi
from pypy.module.micronumpy.base import W_NDimArray
from pypy.module.micronumpy.iter import PureShapeIterator
+from pypy.module.micronumpy import constants
+from pypy.module.micronumpy.support import int_w
call2_driver = jit.JitDriver(name='numpy_call2',
greens = ['shapelen', 'func', 'calc_dtype',
@@ -486,3 +489,36 @@
to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
to_iter.next()
from_iter.next()
+
+choose_driver = jit.JitDriver(greens = ['shapelen', 'mode', 'dtype'],
+                              reds = ['shape', 'iterators', 'arr_iter',
+                                      'out_iter'])
+
+def choose(space, arr, choices, shape, dtype, out, mode):
+    """Fill ``out`` with elements picked from ``choices``.
+
+    At every position of ``shape``, the integer read from ``arr`` selects
+    which array in ``choices`` supplies the value; that value is converted to
+    ``dtype`` and stored into ``out``.  Selectors outside
+    [0, len(choices) - 1] are handled according to ``mode`` (a
+    constants.MODE_* value): MODE_RAISE raises ValueError, MODE_WRAP reduces
+    them modulo len(choices), and MODE_CLIP clamps them to the nearest end.
+    """
+    shapelen = len(shape)
+    iterators = [a.create_iter(shape) for a in choices]
+    arr_iter = arr.create_iter(shape)
+    out_iter = out.create_iter(shape)
+    while not arr_iter.done():
+        # The keyword names here must stay in sync with choose_driver's
+        # greens/reds lists above.
+        choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                      mode=mode, shape=shape,
+                                      iterators=iterators, arr_iter=arr_iter,
+                                      out_iter=out_iter)
+        index = int_w(space, arr_iter.getitem())
+        if index < 0 or index >= len(iterators):
+            if mode == constants.MODE_RAISE:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "invalid entry in choice array"))
+            elif mode == constants.MODE_WRAP:
+                index = index % len(iterators)
+            else:
+                assert mode == constants.MODE_CLIP
+                if index < 0:
+                    index = 0
+                else:
+                    index = len(iterators) - 1
+        out_iter.setitem(iterators[index].getitem().convert_to(dtype))
+        # Advance every choice iterator in lockstep with the selector and the
+        # output ('choice_iter' avoids shadowing the builtin 'iter').
+        for choice_iter in iterators:
+            choice_iter.next()
+        out_iter.next()
+        arr_iter.next()
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -1,5 +1,11 @@
from pypy.rlib import jit
+from pypy.interpreter.error import OperationError
+def int_w(space, w_obj):
+    """Unwrap ``w_obj`` to an interp-level integer.
+
+    ``space.index`` is tried first.  If that path (including the final
+    ``space.int_w``) raises any OperationError, the object is converted with
+    ``space.int`` instead and that result is unwrapped.
+    """
+    try:
+        return space.int_w(space.index(w_obj))
+    except OperationError:
+        pass
+    return space.int_w(space.int(w_obj))
@jit.unroll_safe
def product(s):
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -108,8 +108,16 @@
a, b, c = array([1, 2, 3]), [4, 5, 6], 13
raises(ValueError, "array([3, 1, 0]).choose([a, b, c])")
raises(ValueError, "array([3, 1, 0]).choose([a, b, c], 'raises')")
+ raises(ValueError, "array([3, 1, 0]).choose([])")
+ raises(ValueError, "array([-1, -2, -3]).choose([a, b, c])")
r = array([4, 1, 0]).choose([a, b, c], mode='clip')
assert (r == [13, 5, 3]).all()
r = array([4, 1, 0]).choose([a, b, c], mode='wrap')
assert (r == [4, 5, 3]).all()
-
+
+
+    def test_choose_dtype(self):
+        # A float choice mixed with integer choices must promote the result
+        # to float, and the selected values must survive that conversion.
+        from _numpypy import array
+        a, b, c = array([1.2, 2, 3]), [4, 5, 6], 13
+        r = array([2, 1, 0]).choose([a, b, c])
+        assert r.dtype == float
+        assert (r == [13, 5, 3]).all()
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit