Author: mattip <[email protected]>
Branch: object-dtype
Changeset: r76150:6239ed06f4ed
Date: 2015-02-25 23:22 +0200
http://bitbucket.org/pypy/pypy/changeset/6239ed06f4ed/

Log:    fix sort and argsort for non-native byte order and object dtypes by
        converting _read/_write into static methods
        (read_from_storage/write_to_storage)

diff --git a/pypy/module/micronumpy/selection.py b/pypy/module/micronumpy/selection.py
--- a/pypy/module/micronumpy/selection.py
+++ b/pypy/module/micronumpy/selection.py
@@ -13,18 +13,19 @@
 INT_SIZE = rffi.sizeof(lltype.Signed)
 
 all_types = (types.all_float_types + types.all_complex_types +
-             types.all_int_types)
-all_types = [i for i in all_types if not issubclass(i[0], types.Float16)]
+             types.all_int_types + [(types.ObjectType, 'object')])
 all_types = unrolling_iterable(all_types)
 
 
 def make_argsort_function(space, itemtype, comp_type, count=1):
     TP = itemtype.T
     step = rffi.sizeof(TP)
+    reader = itemtype.read_from_storage
+    writer = itemtype.write_to_storage
 
     class Repr(object):
         def __init__(self, index_stride_size, stride_size, size, values,
-                     indexes, index_start, start):
+                     indexes, index_start, start, native):
             self.index_stride_size = index_stride_size
             self.stride_size = stride_size
             self.index_start = index_start
@@ -32,39 +33,34 @@
             self.size = size
             self.values = values
             self.indexes = indexes
+            self.native = native
 
-        def getitem(self, item):
-            if count < 2:
-                v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
-                                    + self.start)
-            else:
-                v = []
-                for i in range(count):
-                    _v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
-                                    + self.start + step * i)
-                    v.append(_v)
+        def getitem(self, idx):
+            v = reader(TP, self.values, self.native, idx * self.stride_size + 
self.start)
             if comp_type == 'int':
                 v = widen(v)
             elif comp_type == 'float':
                 v = float(v)
             elif comp_type == 'complex':
                 v = [float(v[0]),float(v[1])]
+            elif comp_type == 'object':
+                pass
             else:
                 raise NotImplementedError('cannot reach')
             return (v, raw_storage_getitem(lltype.Signed, self.indexes,
-                                           item * self.index_stride_size +
+                                           idx * self.index_stride_size +
                                            self.index_start))
 
         def setitem(self, idx, item):
-            if count < 2:
-                raw_storage_setitem(self.values, idx * self.stride_size +
+            if comp_type == 'object':
+                writer(TP, self.values, self.native, idx * self.stride_size +
+                                self.start, item[0])
+            elif count == 1:
+                writer(TP, self.values, self.native, idx * self.stride_size +
                                 self.start, rffi.cast(TP, item[0]))
             else:
-                i = 0
-                for val in item[0]:
-                    raw_storage_setitem(self.values, idx * self.stride_size +
-                                self.start + i*step, rffi.cast(TP, val))
-                    i += 1
+                writer(TP, self.values, self.native, idx * self.stride_size +
+                                self.start, [rffi.cast(TP, i) for i in 
item[0]])
             raw_storage_setitem(self.indexes, idx * self.index_stride_size +
                                 self.index_start, item[1])
 
@@ -76,7 +72,7 @@
             values = alloc_raw_storage(size * stride_size,
                                             track_allocation=False)
             Repr.__init__(self, dtype.elsize, stride_size,
-                          size, values, indexes, start, start)
+                          size, values, indexes, start, start, True)
 
         def __del__(self):
             free_raw_storage(self.indexes, track_allocation=False)
@@ -135,11 +131,12 @@
         dtype = descriptor.get_dtype_cache(space).w_longdtype
         index_arr = W_NDimArray.from_shape(space, arr.get_shape(), dtype)
         storage = index_arr.implementation.get_storage()
+        native = arr.dtype.is_native()
         if len(arr.get_shape()) == 1:
             for i in range(arr.get_size()):
                 raw_storage_setitem(storage, i * INT_SIZE, i)
             r = Repr(INT_SIZE, itemsize, arr.get_size(), arr.get_storage(),
-                     storage, 0, arr.start)
+                     storage, 0, arr.start, native)
             ArgSort(r).sort()
         else:
             shape = arr.get_shape()
@@ -160,7 +157,8 @@
                     raw_storage_setitem(storage, i * index_stride_size +
                                         index_state.offset, i)
                 r = Repr(index_stride_size, stride_size, axis_size,
-                         arr.get_storage(), storage, index_state.offset, 
arr_state.offset)
+                         arr.get_storage(), storage, index_state.offset,
+                         arr_state.offset, native)
                 ArgSort(r).sort()
                 arr_state = arr_iter.next(arr_state)
                 index_state = index_iter.next(index_state)
@@ -185,52 +183,49 @@
 def make_sort_function(space, itemtype, comp_type, count=1):
     TP = itemtype.T
     step = rffi.sizeof(TP)
+    reader = itemtype.read_from_storage
+    writer = itemtype.write_to_storage
 
     class Repr(object):
-        def __init__(self, stride_size, size, values, start):
+        def __init__(self, stride_size, size, values, start, native):
             self.stride_size = stride_size
             self.start = start
             self.size = size
             self.values = values
+            self.native = native
 
-        def getitem(self, item):
-            if count < 2:
-                v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
-                                    + self.start)
-            else:
-                v = []
-                for i in range(count):
-                    _v = raw_storage_getitem(TP, self.values, item * 
self.stride_size
-                                    + self.start + step * i)
-                    v.append(_v)
+        def getitem(self, idx):
+            v = reader(TP, self.values, self.native, idx * self.stride_size + 
self.start)
             if comp_type == 'int':
                 v = widen(v)
             elif comp_type == 'float':
                 v = float(v)
             elif comp_type == 'complex':
                 v = [float(v[0]),float(v[1])]
+            elif comp_type == 'object':
+                pass
             else:
                 raise NotImplementedError('cannot reach')
             return (v)
 
         def setitem(self, idx, item):
-            if count < 2:
-                raw_storage_setitem(self.values, idx * self.stride_size +
+            if comp_type == 'object':
+                writer(TP, self.values, self.native, idx * self.stride_size +
+                                self.start, item)
+            elif count == 1:
+                writer(TP, self.values, self.native, idx * self.stride_size +
                                 self.start, rffi.cast(TP, item))
             else:
-                i = 0
-                for val in item:
-                    raw_storage_setitem(self.values, idx * self.stride_size +
-                                self.start + i*step, rffi.cast(TP, val))
-                    i += 1
+                writer(TP, self.values, self.native, idx * self.stride_size +
+                                self.start, [rffi.cast(TP, i) for i in item])
 
     class ArgArrayRepWithStorage(Repr):
-        def __init__(self, stride_size, size):
+        def __init__(self, stride_size, size, native):
             start = 0
             values = alloc_raw_storage(size * stride_size,
                                             track_allocation=False)
             Repr.__init__(self, stride_size,
-                          size, values, start)
+                          size, values, start, native)
 
         def __del__(self):
             free_raw_storage(self.values, track_allocation=False)
@@ -283,9 +278,10 @@
         else:
             axis = space.int_w(w_axis)
         # create array of indexes
+        native = arr.dtype.is_native()
         if len(arr.get_shape()) == 1:
             r = Repr(itemsize, arr.get_size(), arr.get_storage(),
-                     arr.start)
+                     arr.start, native)
             ArgSort(r).sort()
         else:
             shape = arr.get_shape()
@@ -298,7 +294,8 @@
             stride_size = arr.strides[axis]
             axis_size = arr.shape[axis]
             while not arr_iter.done(arr_state):
-                r = Repr(stride_size, axis_size, arr.get_storage(), 
arr_state.offset)
+                r = Repr(stride_size, axis_size, arr.get_storage(),
+                                                 arr_state.offset, native)
                 ArgSort(r).sort()
                 arr_state = arr_iter.next(arr_state)
 
@@ -308,9 +305,6 @@
 def sort_array(arr, space, w_axis, w_order):
     cache = space.fromcache(SortCache)  # that populates SortClasses
     itemtype = arr.dtype.itemtype
-    if arr.dtype.byteorder == NPY.OPPBYTE:
-        raise oefmt(space.w_NotImplementedError,
-                    "sorting of non-native byteorder not supported yet")
     for tp in all_types:
         if isinstance(itemtype, tp[0]):
             return cache._lookup(tp)(arr, space, w_axis,
diff --git a/pypy/module/micronumpy/test/test_selection.py b/pypy/module/micronumpy/test/test_selection.py
--- a/pypy/module/micronumpy/test/test_selection.py
+++ b/pypy/module/micronumpy/test/test_selection.py
@@ -6,20 +6,17 @@
         assert array(2.0).argsort() == 0
         nnp = self.non_native_prefix
         for dtype in ['int', 'float', 'int16', 'float32', 'uint64',
-                      nnp + 'i2', complex]:
+                      nnp + 'i2', complex, 'float16']:
             a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
             exp = list(a)
             exp = sorted(range(len(exp)), key=exp.__getitem__)
             c = a.copy()
             res = a.argsort()
-            assert (res == exp).all(), '%r\n%r\n%r' % (a,res,exp)
+            assert (res == exp).all(), 'Failed sortng 
%r\na=%r\nres=%r\nexp=%r' % (dtype,a,res,exp)
             assert (a == c).all() # not modified
 
-            a = arange(100, dtype=dtype)
-            assert (a.argsort() == a).all()
-        import sys
-        if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, 'arange(10,dtype="float16").argsort()')
+            #a = arange(100, dtype=dtype)
+            #assert (a.argsort() == a).all()
 
     def test_argsort_ndim(self):
         from numpy import array
@@ -63,14 +60,13 @@
                       'i2', complex]:
             a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
             exp = sorted(list(a))
-            res = a.copy()
-            res.sort()
-            assert (res == exp).all(), '%r\n%r\n%r' % (a,res,exp)
+            a.sort()
+            assert (a == exp).all(), 'Failed sorting %r\n%r\n%r' % (dtype, a, 
exp)
 
             a = arange(100, dtype=dtype)
             c = a.copy()
             a.sort()
-            assert (a == c).all()
+            assert (a == c).all(), 'Failed sortng %r\na=%r\nc=%r' % (dtype,a,c)
 
     def test_sort_nonnative(self):
         from numpy import array
@@ -79,12 +75,9 @@
             a = array([6, 4, -1, 3, 8, 3, 256+20, 100, 101], dtype=dtype)
             b = array([-1, 3, 3, 4, 6, 8, 100, 101, 256+20], dtype=dtype)
             c = a.copy()
-            import sys
-            if '__pypy__' in sys.builtin_module_names:
-                exc = raises(NotImplementedError, a.sort)
-                assert exc.value[0].find('supported') >= 0
-            #assert (a == b).all(), \
-            #    'a,orig,dtype %r,%r,%r' % (a,c,dtype)
+            a.sort()
+            assert (a == b).all(), \
+                'a,orig,dtype %r,%r,%r' % (a,c,dtype)
 
 # tests from numpy/tests/test_multiarray.py
     def test_sort_corner_cases(self):
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -30,7 +30,7 @@
 
 
 def simple_unary_op(func):
-    specialize.argtype(1)(func)
+    specialize.argtype(2)(func)
     @functools.wraps(func)
     def dispatcher(self, space, v):
         return self.box(
@@ -66,7 +66,7 @@
     return dispatcher
 
 def raw_unary_op(func):
-    specialize.argtype(1)(func)
+    specialize.argtype(2)(func)
     @functools.wraps(func)
     def dispatcher(self, space, v):
         return func(
@@ -172,30 +172,34 @@
     def default_fromstring(self, space):
         raise NotImplementedError
 
-    def _read(self, storage, i, offset):
-        res = raw_storage_getitem_unaligned(self.T, storage, i + offset)
-        if not self.native:
+    @staticmethod
+    def read_from_storage(T, storage, native, offset):
+        res = raw_storage_getitem_unaligned(T, storage, offset)
+        if not native:
             res = byteswap(res)
         return res
 
-    def _write(self, storage, i, offset, value):
-        if not self.native:
+    @staticmethod
+    def write_to_storage(T, storage, native, offset, value):
+        if not native:
             value = byteswap(value)
-        raw_storage_setitem_unaligned(storage, i + offset, value)
+        raw_storage_setitem_unaligned(storage, offset, value)
 
     def read(self, arr, i, offset, dtype=None):
-        return self.box(self._read(arr.storage, i, offset))
+        return self.box(self.read_from_storage(
+                                self.T, arr.storage, self.native, i + offset))
 
     def read_bool(self, arr, i, offset):
-        return bool(self.for_computation(self._read(arr.storage, i, offset)))
+        return bool(self.for_computation(self.read_from_storage(
+                self.T, arr.storage, self.native, i + offset)))
 
     def store(self, arr, i, offset, box):
-        self._write(arr.storage, i, offset, self.unbox(box))
+        self.write_to_storage(self.T, arr.storage, self.native, i + offset, 
self.unbox(box))
 
     def fill(self, storage, width, box, start, stop, offset):
         value = self.unbox(box)
         for i in xrange(start, stop, width):
-            self._write(storage, i, offset, value)
+            self.write_to_storage(self.T, storage, self.native, i + offset, 
value)
 
     def runpack_str(self, space, s):
         v = rffi.cast(self.T, runpack(self.format_code, s))
@@ -967,8 +971,8 @@
         else:
             return x
 
+FLOAT16_STORAGE_T = rffi.USHORT
 class Float16(BaseType, Float):
-    _STORAGE_T = rffi.USHORT
     T = rffi.SHORT
     BoxType = boxes.W_Float16Box
 
@@ -989,24 +993,26 @@
     def byteswap(self, w_v):
         value = self.unbox(w_v)
         hbits = float_pack(value, 2)
-        swapped = byteswap(rffi.cast(self._STORAGE_T, hbits))
+        swapped = byteswap(rffi.cast(FLOAT16_STORAGE_T, hbits))
         return self.box(float_unpack(r_ulonglong(swapped), 2))
 
-    def _read(self, storage, i, offset):
-        hbits = raw_storage_getitem_unaligned(self._STORAGE_T, storage, i + 
offset)
-        if not self.native:
+    @staticmethod
+    def read_from_storage(T, storage, native, offset):
+        hbits = raw_storage_getitem_unaligned(FLOAT16_STORAGE_T, storage, 
offset)
+        if not native:
             hbits = byteswap(hbits)
         return float_unpack(r_ulonglong(hbits), 2)
 
-    def _write(self, storage, i, offset, value):
+    @staticmethod
+    def write_to_storage(T, storage, native, offset, value):
         try:
             hbits = float_pack(value, 2)
         except OverflowError:
             hbits = float_pack(rfloat.INFINITY, 2)
-        hbits = rffi.cast(self._STORAGE_T, hbits)
-        if not self.native:
+        hbits = rffi.cast(FLOAT16_STORAGE_T, hbits)
+        if not native:
             hbits = byteswap(hbits)
-        raw_storage_setitem_unaligned(storage, i + offset, hbits)
+        raw_storage_setitem_unaligned(storage, offset, hbits)
 
 class Float32(BaseType, Float):
     T = rffi.FLOAT
@@ -1084,7 +1090,8 @@
         return bool(real) or bool(imag)
 
     def read_bool(self, arr, i, offset):
-        v = self.for_computation(self._read(arr.storage, i, offset))
+        v = self.for_computation(self.read_from_storage(
+                self.T, arr.storage, self.native, i + offset))
         return bool(v[0]) or bool(v[1])
 
     def get_element_size(self):
@@ -1127,33 +1134,36 @@
         assert isinstance(box, self.BoxType)
         return box.real, box.imag
 
-    def _read(self, storage, i, offset):
-        real = raw_storage_getitem_unaligned(self.T, storage, i + offset)
-        imag = raw_storage_getitem_unaligned(self.T, storage, i + offset + 
rffi.sizeof(self.T))
-        if not self.native:
+    @staticmethod
+    def read_from_storage(T, storage, native, offset):
+        real = raw_storage_getitem_unaligned(T, storage, offset)
+        imag = raw_storage_getitem_unaligned(T, storage, offset + 
rffi.sizeof(T))
+        if not native:
             real = byteswap(real)
             imag = byteswap(imag)
         return real, imag
 
     def read(self, arr, i, offset, dtype=None):
-        real, imag = self._read(arr.storage, i, offset)
+        real, imag = self.read_from_storage(
+                                self.T, arr.storage, self.native, i + offset)
         return self.box_complex(real, imag)
 
-    def _write(self, storage, i, offset, value):
+    @staticmethod
+    def write_to_storage(T, storage, native, offset, value):
         real, imag = value
-        if not self.native:
+        if not native:
             real = byteswap(real)
             imag = byteswap(imag)
-        raw_storage_setitem_unaligned(storage, i + offset, real)
-        raw_storage_setitem_unaligned(storage, i + offset + 
rffi.sizeof(self.T), imag)
+        raw_storage_setitem_unaligned(storage, offset, real)
+        raw_storage_setitem_unaligned(storage, offset + rffi.sizeof(T), imag)
 
     def store(self, arr, i, offset, box):
-        self._write(arr.storage, i, offset, self.unbox(box))
+        self.write_to_storage(self.T, arr.storage, self.native, i + offset, 
self.unbox(box))
 
     def fill(self, storage, width, box, start, stop, offset):
         value = self.unbox(box)
         for i in xrange(start, stop, width):
-            self._write(storage, i, offset, value)
+            self.write_to_storage(self.T, storage, self.native, i + offset, 
value)
 
     @complex_binary_op
     def add(self, space, v1, v2):
@@ -1642,21 +1652,24 @@
         return boxes.W_ObjectBox(w_item)
 
     def store(self, arr, i, offset, box):
-        self._write(arr.storage, i, offset, self.unbox(box))
+        self.write_to_storage(self.T, arr.storage, self.native, i + offset, 
self.unbox(box))
 
     def read(self, arr, i, offset, dtype=None):
-        return self.box(self._read(arr.storage, i, offset))
+        return self.box(self.read_from_storage(
+                            self.T, arr.storage, self.native, i + offset))
 
-    def _write(self, storage, i, offset, w_obj):
+    @staticmethod
+    def write_to_storage(T, storage, native, offset, w_obj):
         if we_are_translated():
             value = rffi.cast(lltype.Signed, cast_instance_to_gcref(w_obj))
         else:
             value = len(_all_objs_for_tests)
             _all_objs_for_tests.append(w_obj)
-        raw_storage_setitem_unaligned(storage, i + offset, value)
+        raw_storage_setitem_unaligned(storage, offset, value)
 
-    def _read(self, storage, i, offset):
-        res = raw_storage_getitem_unaligned(self.T, storage, i + offset)
+    @staticmethod
+    def read_from_storage(T, storage, native, offset):
+        res = raw_storage_getitem_unaligned(T, storage, offset)
         if we_are_translated():
             gcref = rffi.cast(llmemory.GCREF, res)
             w_obj = cast_gcref_to_instance(W_Root, gcref)
@@ -1667,13 +1680,16 @@
     def fill(self, storage, width, box, start, stop, offset):
         value = self.unbox(box)
         for i in xrange(start, stop, width):
-            self._write(storage, i, offset, value)
+            self.write_to_storage(self.T, storage, self.native, i + offset, 
value)
 
     def unbox(self, box):
         assert isinstance(box, self.BoxType)
         return box.w_obj
 
+    @specialize.argtype(1)
     def box(self, w_obj):
+        if we_are_translated():
+            assert isinstance(w_obj, W_Root)
         return self.BoxType(w_obj)
 
     def str_format(self, space, box):
@@ -1685,11 +1701,11 @@
 
     @simple_binary_op
     def add(self, space, v1, v2):
-        return space.add(v1, v2)
+        return space.add(space.wrap(v1), space.wrap(v2))
 
     @raw_binary_op
     def eq(self, space, v1, v2):
-        return space.eq_w(v1, v2)
+        return space.eq_w(space.wrap(v1), space.wrap(v2))
 
 class FlexibleType(BaseType):
     def get_element_size(self):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -949,7 +949,7 @@
         return dt2
 
     if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT:
-        return descriptor.get_dtype_cache(space).w_objectdtype
+        return get_dtype_cache(space).w_objectdtype
 
     # dt1.num should be <= dt2.num
     if dt1.num > dt2.num:
@@ -1065,7 +1065,7 @@
     uint64_dtype = get_dtype_cache(space).w_uint64dtype
     complex_dtype = get_dtype_cache(space).w_complex128dtype
     float_dtype = get_dtype_cache(space).w_float64dtype
-    object_dtype = descriptor.get_dtype_cache(space).w_objectdtype
+    object_dtype = get_dtype_cache(space).w_objectdtype
     if isinstance(w_obj, boxes.W_GenericBox):
         dtype = w_obj.get_dtype(space)
         return find_binop_result_dtype(space, dtype, current_guess)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to