Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r73859:cb9125b5bf31
Date: 2014-10-09 01:41 -0400
http://bitbucket.org/pypy/pypy/changeset/cb9125b5bf31/

Log:    implement searchside_converter

diff --git a/pypy/module/micronumpy/constants.py b/pypy/module/micronumpy/constants.py
--- a/pypy/module/micronumpy/constants.py
+++ b/pypy/module/micronumpy/constants.py
@@ -65,6 +65,9 @@
 FLOATINGLTR = 'f'
 COMPLEXLTR = 'c'
 
+SEARCHLEFT = 0
+SEARCHRIGHT = 1
+
 ANYORDER = -1
 CORDER = 0
 FORTRANORDER = 1
diff --git a/pypy/module/micronumpy/converters.py b/pypy/module/micronumpy/converters.py
--- a/pypy/module/micronumpy/converters.py
+++ b/pypy/module/micronumpy/converters.py
@@ -1,4 +1,4 @@
-from pypy.interpreter.error import OperationError
+from pypy.interpreter.error import OperationError, oefmt
 from pypy.module.micronumpy import constants as NPY
 
 
@@ -41,6 +41,23 @@
                          space.wrap("clipmode not understood"))
 
 
+def searchside_converter(space, w_obj):
+    try:
+        s = space.str_w(w_obj)
+    except OperationError:
+        s = None
+    if not s:
+        raise oefmt(space.w_ValueError,
+                    "expected nonempty string for keyword 'side'")
+    if s[0] == 'l' or s[0] == 'L':
+        return NPY.SEARCHLEFT
+    elif s[0] == 'r' or s[0] == 'R':
+        return NPY.SEARCHRIGHT
+    else:
+        raise oefmt(space.w_ValueError,
+                    "'%s' is an invalid value for keyword 'side'", s)
+
+
 def order_converter(space, w_order, default):
     if space.is_none(w_order):
         return default
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -16,7 +16,7 @@
     ArrayArgumentException, wrap_impl
 from pypy.module.micronumpy.concrete import BaseConcreteArray
 from pypy.module.micronumpy.converters import multi_axis_converter, \
-    order_converter, shape_converter
+    order_converter, shape_converter, searchside_converter
 from pypy.module.micronumpy.flagsobj import W_FlagsObject
 from pypy.module.micronumpy.flatiter import W_FlatIterator
 from pypy.module.micronumpy.strides import get_shape_from_iterable, \
@@ -728,21 +728,12 @@
         loop.round(space, self, calc_dtype, self.get_shape(), decimals, out)
         return out
 
-    @unwrap_spec(side=str, w_sorter=WrappedDefault(None))
-    def descr_searchsorted(self, space, w_v, side='left', w_sorter=None):
+    @unwrap_spec(w_side=WrappedDefault('left'), w_sorter=WrappedDefault(None))
+    def descr_searchsorted(self, space, w_v, w_side=None, w_sorter=None):
         if not space.is_none(w_sorter):
             raise OperationError(space.w_NotImplementedError, space.wrap(
                 'sorter not supported in searchsort'))
-        if not side or len(side) < 1:
-            raise OperationError(space.w_ValueError, space.wrap(
-                "expected nonempty string for keyword 'side'"))
-        elif side[0] == 'l' or side[0] == 'L':
-            side = 'l'
-        elif side[0] == 'r' or side[0] == 'R':
-            side = 'r'
-        else:
-            raise oefmt(space.w_ValueError,
-                        "'%s' is an invalid value for keyword 'side'", side)
+        side = searchside_converter(space, w_side)
         if len(self.get_shape()) > 1:
             raise oefmt(space.w_ValueError, "a must be a 1-d array")
         v = convert_to_array(space, w_v)
@@ -1321,7 +1312,7 @@
                 else:
                     imax = imid
             return imin
-        if side == 'l':
+        if side == 0:
             op = operator.lt
         else:
             op = operator.le
diff --git a/pypy/module/micronumpy/test/test_sorting.py b/pypy/module/micronumpy/test/test_sorting.py
--- a/pypy/module/micronumpy/test/test_sorting.py
+++ b/pypy/module/micronumpy/test/test_sorting.py
@@ -366,6 +366,12 @@
         ret = a.searchsorted(3, side='right')
         assert ret == 3
         assert isinstance(ret, np.generic)
+        exc = raises(ValueError, a.searchsorted, 3, side=None)
+        assert str(exc.value) == "expected nonempty string for keyword 'side'"
+        exc = raises(ValueError, a.searchsorted, 3, side='')
+        assert str(exc.value) == "expected nonempty string for keyword 'side'"
+        exc = raises(ValueError, a.searchsorted, 3, side=2)
+        assert str(exc.value) == "expected nonempty string for keyword 'side'"
         ret = a.searchsorted([-10, 10, 2, 3])
         assert (ret == [0, 5, 1, 2]).all()
         if '__pypy__' in sys.builtin_module_names:
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to