Author: mattip <[email protected]>
Branch: missing-ndarray-attributes
Changeset: r60869:593666df7502
Date: 2013-02-04 19:21 +0200
http://bitbucket.org/pypy/pypy/changeset/593666df7502/

Log:    Test that float16 is not argsortable; fix argsort for ints and
        complex.

diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
--- a/pypy/module/micronumpy/arrayimpl/sort.py
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -8,6 +8,7 @@
 from rpython.rlib.rawstorage import raw_storage_getitem, raw_storage_setitem, \
         free_raw_storage, alloc_raw_storage
 from rpython.rlib.unroll import unrolling_iterable
+from rpython.rlib.rarithmetic import intmask
 from rpython.rlib.objectmodel import specialize
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy.base import W_NDimArray
@@ -35,21 +36,20 @@
             if count < 2:
                 v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
                                     + self.start)
-                if comp_type=='int':
-                    v = int(v)
-                elif comp_type=='float':
-                    v = float(v)
-                elif comp_type=='complex':
-                    v = float(v[0]),float(v[1])
-                else:
-                    raise NotImplementedError('cannot reach')
             else:
                 v = []
                 for i in range(count):
                     _v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
                                     + self.start + step * i)
                     v.append(_v)
-                v = for_computation(v)
+            if comp_type=='int':
+                v = intmask(v)
+            elif comp_type=='float':
+                v = float(v)
+            elif comp_type=='complex':
+                v = [float(v[0]),float(v[1])]
+            else:
+                raise NotImplementedError('cannot reach')
             return (v, raw_storage_getitem(lltype.Signed, self.indexes,
                                            item * self.index_stride_size +
                                            self.index_start))
@@ -59,9 +59,11 @@
                 raw_storage_setitem(self.values, idx * self.stride_size +
                                 self.start, rffi.cast(TP, item[0]))
             else:
-                for i in range(count):
+                i = 0
+                for val in item[0]:
                     raw_storage_setitem(self.values, idx * self.stride_size +
-                                self.start + i*step, rffi.cast(TP, item[0][i]))
+                                self.start + i*step, rffi.cast(TP, val))
+                    i += 1
             raw_storage_setitem(self.indexes, idx * self.index_stride_size +
                                 self.index_start, item[1])
 
@@ -94,9 +96,20 @@
         for i in range(stop-start):
             retval.setitem(i, lst.getitem(i+start))
         return retval
-
-    def arg_lt(a, b):
-        return a[0] < b[0]
+    
+    if count < 2:
+        def arg_lt(a, b):
+            # Does numpy do <= ?
+            return a[0] < b[0]
+    else:
+        def arg_lt(a, b):
+            for i in range(count):
+                if a[0][i] < b[0][i]:
+                    return True
+                elif a[0][i] > b[0][i]:
+                    return False
+            # Does numpy do True?    
+            return False
 
     ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                  arg_getitem_slice, arg_lt)
@@ -174,7 +187,7 @@
         self.built = True
         cache = {}
         for cls, it in all_types._items:
-            if cls in types.all_complex_types:
+            if it == 'complex':
                 cache[cls] = make_sort_function(space, cls, it, 2)
             else:
                 cache[cls] = make_sort_function(space, cls, it)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2373,7 +2373,7 @@
         from _numpypy import array, arange
         assert array(2.0).argsort() == 0
         nnp = self.non_native_prefix
-        for dtype in ['int', 'float', 'int16', 'float32', 'float16', 
+        for dtype in ['int', 'float', 'int16', 'float32', 'uint64', 
                         nnp + 'i2', complex]:
             a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
             c = a.copy()
@@ -2383,6 +2383,7 @@
             assert (a == c).all() # not modified
             a = arange(100)
             assert (a.argsort() == a).all()
+        raises(NotImplementedError, 'arange(10,dtype="float16").argsort()')    
 
     def test_argsort_nd(self):
         from _numpypy import array
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to