Author: Alex Gaynor <[email protected]>
Branch: numpy-dtype-alt
Changeset: r46506:20c9eeb81a95
Date: 2011-08-14 12:28 -0700
http://bitbucket.org/pypy/pypy/changeset/20c9eeb81a95/

Log:    Refactor the world to make dtypes control things. WIP.

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -3,7 +3,8 @@
 It should not be imported by the module itself
 """
 
-from pypy.module.micronumpy.interp_numarray import FloatWrapper, SingleDimArray, BaseArray
+from pypy.module.micronumpy.interp_numarray import Scalar, SingleDimArray, BaseArray
+
 
 class BogusBytecode(Exception):
     pass
@@ -35,7 +36,7 @@
             stack.append(create_array(array_size))
             i += 1
         elif b == 'f':
-            stack.append(FloatWrapper(1.2))
+            stack.append(Scalar(1.2))
         elif b == '+':
             right = stack.pop()
             stack.append(stack.pop().descr_add(space, right))
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -1,11 +1,16 @@
+import functools
+import math
+
 from pypy.interpreter.baseobjspace import Wrappable
 from pypy.interpreter.gateway import interp2app
 from pypy.interpreter.typedef import TypeDef, interp_attrproperty
 from pypy.objspace.std.floatobject import float2string
+from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rfloat import DTSF_STR_PRECISION
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.lltypesystem import lltype, llmemory, rffi
 
+
 SIGNEDLTR = "i"
 
 class W_Dtype(Wrappable):
@@ -55,20 +60,104 @@
         ))
 
     def getitem(self, storage, i):
-        return self.unerase(storage)[i]
+        return self.Box(self.unerase(storage)[i])
 
     def setitem(self, storage, i, item):
-        self.unerase(storage)[i] = item
+        self.unerase(storage)[i] = item.val
 
     def setitem_w(self, space, storage, i, w_item):
-        self.setitem(storage, i, self.unwrap(space, w_item))
+        self.setitem(storage, i, self.Box(self.unwrap(space, w_item)))
+
+    @specialize.argtype(1)
+    def adapt_val(self, val):
+        return self.Box(rffi.cast(self.TP.TO.OF, val))
 
     def str_format(self, item):
-        return str(item)
+        assert isinstance(item, self.Box)
+        return str(item.val)
+
+    # Operations.
+    def binop(func):
+        @functools.wraps(func)
+        def impl(self, v1, v2):
+            assert isinstance(v1, self.Box)
+            assert isinstance(v2, self.Box)
+            return self.Box(func(self, v1.val, v2.val))
+        return impl
+    def unaryop(func):
+        @functools.wraps(func)
+        def impl(self, v):
+            assert isinstance(v, self.Box)
+            return self.Box(func(self, v.val))
+        return impl
+
+    @binop
+    def add(self, v1, v2):
+        return v1 + v2
+    @binop
+    def sub(self, v1, v2):
+        return v1 - v2
+    @binop
+    def mul(self, v1, v2):
+        return v1 * v2
+    @binop
+    def div(self, v1, v2):
+        return v1 / v2
+    @binop
+    def mod(self, v1, v2):
+        return math.fmod(v1, v2)
+    @binop
+    def pow(self, v1, v2):
+        return math.pow(v1, v2)
+    @unaryop
+    def neg(self, v):
+        return -v
+    @unaryop
+    def pos(self, v):
+        return v
+    @unaryop
+    def abs(self, v):
+        return abs(v)
+    @binop
+    def max(self, v1, v2):
+        return max(v1, v2)
+    @binop
+    def min(self, v1, v2):
+        return min(v1, v2)
+
+    # Comparisons, they return unwraped results (for now)
+    def ne(self, v1, v2):
+        assert isinstance(v1, self.Box)
+        assert isinstance(v2, self.Box)
+        return v1.val != v2.val
+    def bool(self, v):
+        assert isinstance(v, self.Box)
+        return bool(v.val)
+
 
 def make_array_ptr(T):
     return lltype.Ptr(lltype.Array(T, hints={"nolength": True}))
 
+class BaseBox(object):
+    _mixin_ = True
+
+    def __init__(self, val):
+        if self.valtype is not None:
+            assert isinstance(val, self.valtype)
+        self.val = val
+
+    def wrap(self, space):
+        return space.wrap(self.val)
+
+    def convert_to(self, dtype):
+        return dtype.adapt_val(self.val)
+
+def make_box(TP, v=None):
+    class Box(BaseBox):
+        valtype = v
+    Box.__name__ = "%sBox" % TP.TO.OF._name
+    return Box
+
 VOID_TP = make_array_ptr(lltype.Void)
 
 class W_BoolDtype(LowLevelDtype, W_Dtype):
@@ -77,6 +166,7 @@
     aliases = ["?"]
     applevel_types = ["bool"]
     TP = make_array_ptr(lltype.Bool)
+    Box = make_box(TP, bool)
 
     def unwrap(self, space, w_item):
         return space.is_true(w_item)
@@ -86,12 +176,14 @@
     kind = SIGNEDLTR
     aliases = ["int8"]
     TP = make_array_ptr(rffi.SIGNEDCHAR)
+    Box = make_box(TP)
 
 class W_Int32Dtype(LowLevelDtype, W_Dtype):
     num = 5
     kind = SIGNEDLTR
     aliases = ["i"]
     TP = make_array_ptr(rffi.INT)
+    Box = make_box(TP)
 
 class W_LongDtype(LowLevelDtype, W_Dtype):
     num = 7
@@ -99,6 +191,7 @@
     aliases = ["l"]
     applevel_types = ["int"]
     TP = make_array_ptr(rffi.LONG)
+    Box = make_box(TP)
 
     def unwrap(self, space, w_item):
         return space.int_w(space.int(w_item))
@@ -107,17 +200,19 @@
     num = 9
     applevel_types = ["long"]
     TP = make_array_ptr(rffi.LONGLONG)
+    Box = make_box(TP)
 
 class W_Float64Dtype(LowLevelDtype, W_Dtype):
     num = 12
     applevel_types = ["float"]
     TP = make_array_ptr(lltype.Float)
+    Box = make_box(TP)
 
     def unwrap(self, space, w_item):
         return space.float_w(space.float(w_item))
 
     def str_format(self, item):
-        return float2string(item, 'g', DTSF_STR_PRECISION)
+        return float2string(item.val, 'g', DTSF_STR_PRECISION)
 
 
 ALL_DTYPES = [
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -7,6 +7,7 @@
 from pypy.module.micronumpy import interp_ufuncs, interp_dtype
 from pypy.module.micronumpy.interp_support import Signature
 from pypy.rlib import jit
+from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rfloat import DTSF_STR_PRECISION
 from pypy.rpython.lltypesystem import lltype
 from pypy.tool.sourcetools import func_with_new_name
@@ -19,15 +20,6 @@
 slice_driver1 = jit.JitDriver(greens=['signature'], reds=['i', 'j', 'step', 'stop', 'source', 'dest'])
 slice_driver2 = jit.JitDriver(greens=['signature'], reds=['i', 'j', 'step', 'stop', 'source', 'dest'])
 
-def add(v1, v2):
-    return v1 + v2
-def mul(v1, v2):
-    return v1 * v2
-def maximum(v1, v2):
-    return max(v1, v2)
-def minimum(v1, v2):
-    return min(v1, v2)
-
 class BaseArray(Wrappable):
     def __init__(self):
         self.invalidates = []
@@ -78,7 +70,7 @@
 
     def _binop_right_impl(w_ufunc):
         def impl(self, space, w_other):
-            w_other = FloatWrapper(space.float_w(w_other))
+            w_other = scalar_w(space, interp_dtype.W_Float64Dtype, w_other)
             return w_ufunc(space, w_other, self)
         return func_with_new_name(impl, "binop_right_%s_impl" % w_ufunc.__name__)
 
@@ -89,34 +81,36 @@
     descr_rpow = _binop_right_impl(interp_ufuncs.power)
     descr_rmod = _binop_right_impl(interp_ufuncs.mod)
 
-    def _reduce_sum_prod_impl(function, init):
+    def _reduce_sum_prod_impl(op_name, init):
         reduce_driver = jit.JitDriver(greens=['signature'],
                          reds = ['i', 'size', 'self', 'result'])
 
-        def loop(self, result, size):
+        def loop(self, res_dtype, result, size):
             i = 0
             while i < size:
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               self=self, size=size, i=i,
                                               result=result)
-                result = function(result, self.eval(i))
+                result = getattr(res_dtype, op_name)(result, self.eval(i))
                 i += 1
             return result
 
         def impl(self, space):
-            return space.wrap(loop(self, init, self.find_size()))
-        return func_with_new_name(impl, "reduce_%s_impl" % function.__name__)
+            result = space.fromcache(interp_dtype.W_Float64Dtype).Box(init).convert_to(self.find_dtype())
+            return loop(self, self.find_dtype(), result, self.find_size()).wrap(space)
+        return func_with_new_name(impl, "reduce_%s_impl" % op_name)
 
-    def _reduce_max_min_impl(function):
+    def _reduce_max_min_impl(op_name):
         reduce_driver = jit.JitDriver(greens=['signature'],
                          reds = ['i', 'size', 'self', 'result'])
         def loop(self, result, size):
             i = 1
+            dtype = self.find_dtype()
             while i < size:
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               self=self, size=size, i=i,
                                               result=result)
-                result = function(result, self.eval(i))
+                result = getattr(dtype, op_name)(result, self.eval(i))
                 i += 1
             return result
 
@@ -125,23 +119,24 @@
             if size == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" \
-                            % function.__name__))
-            return space.wrap(loop(self, self.eval(0), size))
-        return func_with_new_name(impl, "reduce_%s_impl" % function.__name__)
+                            % op_name))
+            return loop(self, self.eval(0), size).wrap(space)
+        return func_with_new_name(impl, "reduce_%s_impl" % op_name)
 
-    def _reduce_argmax_argmin_impl(function):
+    def _reduce_argmax_argmin_impl(op_name):
         reduce_driver = jit.JitDriver(greens=['signature'],
                          reds = ['i', 'size', 'result', 'self', 'cur_best'])
         def loop(self, size):
             result = 0
             cur_best = self.eval(0)
             i = 1
+            dtype = self.find_dtype()
             while i < size:
                 reduce_driver.jit_merge_point(signature=self.signature,
                                               self=self, size=size, i=i,
                                               result=result, cur_best=cur_best)
-                new_best = function(cur_best, self.eval(i))
-                if new_best != cur_best:
+                new_best = getattr(dtype, op_name)(cur_best, self.eval(i))
+                if dtype.ne(new_best, cur_best):
                     result = i
                     cur_best = new_best
                 i += 1
@@ -151,16 +146,17 @@
             if size == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" \
-                            % function.__name__))
+                            % op_name))
             return space.wrap(loop(self, size))
-        return func_with_new_name(impl, "reduce_arg%s_impl" % function.__name__)
+        return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
 
     def _all(self):
         size = self.find_size()
+        dtype = self.find_dtype()
         i = 0
         while i < size:
             all_driver.jit_merge_point(signature=self.signature, self=self, size=size, i=i)
-            if not self.eval(i):
+            if not dtype.bool(self.eval(i)):
                 return False
             i += 1
         return True
@@ -169,22 +165,23 @@
 
     def _any(self):
         size = self.find_size()
+        dtype = self.find_dtype()
         i = 0
         while i < size:
             any_driver.jit_merge_point(signature=self.signature, self=self, size=size, i=i)
-            if self.eval(i):
+            if dtype.bool(self.eval(i)):
                 return True
             i += 1
         return False
     def descr_any(self, space):
         return space.wrap(self._any())
 
-    descr_sum = _reduce_sum_prod_impl(add, 0.0)
-    descr_prod = _reduce_sum_prod_impl(mul, 1.0)
-    descr_max = _reduce_max_min_impl(maximum)
-    descr_min = _reduce_max_min_impl(minimum)
-    descr_argmax = _reduce_argmax_argmin_impl(maximum)
-    descr_argmin = _reduce_argmax_argmin_impl(minimum)
+    descr_sum = _reduce_sum_prod_impl("add", 0.0)
+    descr_prod = _reduce_sum_prod_impl("mul", 1.0)
+    descr_max = _reduce_max_min_impl("max")
+    descr_min = _reduce_max_min_impl("min")
+    descr_argmax = _reduce_argmax_argmin_impl("max")
+    descr_argmin = _reduce_argmax_argmin_impl("min")
 
     def descr_dot(self, space, w_other):
         if isinstance(w_other, BaseArray):
@@ -240,7 +237,7 @@
         start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())
         if step == 0:
             # Single index
-            return space.wrap(self.get_concrete().eval(start))
+            return self.get_concrete().eval(start).wrap(space)
         else:
             # Slice
             res = SingleDimSlice(start, stop, step, slice_length, self, self.signature.transition(SingleDimSlice.static_signature))
@@ -302,18 +299,27 @@
         return w_obj
     else:
         # If it's a scalar
-        return FloatWrapper(space.float_w(w_obj))
+        return scalar_w(space, interp_dtype.W_Float64Dtype, w_obj)
 
-class FloatWrapper(BaseArray):
+@specialize.arg(1)
+def scalar_w(space, dtype, w_obj):
+    return Scalar(scalar(space, dtype, w_obj))
+
+@specialize.arg(1)
+def scalar(space, dtype, w_obj):
+    dtype = space.fromcache(dtype)
+    return dtype.Box(dtype.unwrap(space, w_obj))
+
+class Scalar(BaseArray):
     """
     Intermediate class representing a float literal.
     """
-    _immutable_fields_ = ["float_value"]
+    _immutable_fields_ = ["value"]
     signature = Signature()
 
-    def __init__(self, float_value):
+    def __init__(self, value):
         BaseArray.__init__(self)
-        self.float_value = float_value
+        self.value = value
 
     def find_size(self):
         raise ValueError
@@ -322,7 +328,7 @@
         raise ValueError
 
     def eval(self, i):
-        return self.float_value
+        return self.value
 
 class VirtualArray(BaseArray):
     """
@@ -394,7 +400,7 @@
         return self.values.find_dtype()
 
     def _eval(self, i):
-        return self.function(self.values.eval(i))
+        return self.function(self.find_dtype(), self.values.eval(i))
 
 class Call2(VirtualArray):
     """
@@ -420,8 +426,10 @@
         return self.right.find_size()
 
     def _eval(self, i):
+        dtype = self.find_dtype()
         lhs, rhs = self.left.eval(i), self.right.eval(i)
-        return self.function(lhs, rhs)
+        lhs, rhs = lhs.convert_to(dtype), rhs.convert_to(dtype)
+        return self.function(dtype, lhs, rhs)
 
     def _find_dtype(self):
         lhs_dtype = None
@@ -564,9 +572,11 @@
 
 @unwrap_spec(size=int)
 def ones(space, size):
-    arr = SingleDimArray(size, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
+    dtype = space.fromcache(interp_dtype.W_Float64Dtype)
+    arr = SingleDimArray(size, dtype=dtype)
+    one = dtype.Box(1.0)
     for i in xrange(size):
-        arr.dtype.setitem(arr.storage, i, 1.0)
+        arr.dtype.setitem(arr.storage, i, one)
     return space.wrap(arr)
 
 BaseArray.typedef = TypeDef(
diff --git a/pypy/module/micronumpy/interp_support.py b/pypy/module/micronumpy/interp_support.py
--- a/pypy/module/micronumpy/interp_support.py
+++ b/pypy/module/micronumpy/interp_support.py
@@ -25,7 +25,7 @@
     i = 0
     while i < number:
         part = s[start:end]
-        a.dtype.setitem(a.storage, i, runpack('d', part))
+        a.dtype.setitem(a.storage, i, a.dtype.Box(runpack('d', part)))
         i += 1
         start += FLOAT_SIZE
         end += FLOAT_SIZE
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -1,20 +1,22 @@
 import math
 
+from pypy.module.micronumpy import interp_dtype
 from pypy.module.micronumpy.interp_support import Signature
 from pypy.rlib import rfloat
 from pypy.tool.sourcetools import func_with_new_name
 
+
 def ufunc(func):
     signature = Signature()
     def impl(space, w_obj):
-        from pypy.module.micronumpy.interp_numarray import Call1, convert_to_array
+        from pypy.module.micronumpy.interp_numarray import Call1, convert_to_array, scalar
         if space.issequence_w(w_obj):
             w_obj_arr = convert_to_array(space, w_obj)
             w_res = Call1(func, w_obj_arr, w_obj_arr.signature.transition(signature))
             w_obj_arr.invalidates.append(w_res)
             return w_res
         else:
-            return space.wrap(func(space.float_w(w_obj)))
+            return func(scalar(interp_dtype.W_Float64_Dtype, w_obj)).wrap(space)
     return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
 
 def ufunc2(func):
@@ -33,21 +35,35 @@
             return space.wrap(func(space.float_w(w_lhs), space.float_w(w_rhs)))
     return func_with_new_name(impl, "%s_dispatcher" % func.__name__)
 
-@ufunc
-def absolute(value):
-    return abs(value)
+def ufunc_dtype_caller(ufunc_name, op_name, argcount):
+    if argcount == 1:
+        @ufunc
+        def impl(res_dtype, value):
+            return getattr(res_dtype, op_name)(value)
+    elif argcount == 2:
+        @ufunc2
+        def impl(res_dtype, lvalue, rvalue):
+            return getattr(res_dtype, op_name)(lvalue, rvalue)
+    impl.__name__ = ufunc_name
+    return impl
 
-@ufunc2
-def add(lvalue, rvalue):
-    return lvalue + rvalue
+for ufunc_name, op_name, argcount in [
+    ("add", "add", 2),
+    ("subtract", "sub", 2),
+    ("multiply", "mul", 2),
+    ("divide", "div", 2),
+    ("mod", "mod", 2),
+    ("power", "pow", 2),
+    ("negative", "neg", 1),
+    ("positive", "pos", 1),
+    ("absolute", "abs", 1),
+]:
+    globals()[ufunc_name] = ufunc_dtype_caller(ufunc_name, op_name, argcount)
 
 @ufunc2
 def copysign(lvalue, rvalue):
     return rfloat.copysign(lvalue, rvalue)
 
-@ufunc2
-def divide(lvalue, rvalue):
-    return lvalue / rvalue
 
 @ufunc
 def exp(value):
@@ -68,18 +84,9 @@
 def minimum(lvalue, rvalue):
     return min(lvalue, rvalue)
 
-@ufunc2
-def multiply(lvalue, rvalue):
-    return lvalue * rvalue
 
-# Used by numarray for __pos__. Not visible from numpy application space.
-@ufunc
-def positive(value):
-    return value
 
-@ufunc
-def negative(value):
-    return -value
+
 
 @ufunc
 def reciprocal(value):
@@ -87,10 +94,6 @@
         return rfloat.copysign(rfloat.INFINITY, value)
     return 1.0 / value
 
-@ufunc2
-def subtract(lvalue, rvalue):
-    return lvalue - rvalue
-
 @ufunc
 def floor(value):
     return math.floor(value)
@@ -113,13 +116,6 @@
 def tan(value):
     return math.tan(value)
 
-@ufunc2
-def power(lvalue, rvalue):
-    return math.pow(lvalue, rvalue)
-
-@ufunc2
-def mod(lvalue, rvalue):
-    return math.fmod(lvalue, rvalue)
 
 
 @ufunc
diff --git a/pypy/module/micronumpy/test/test_base.py b/pypy/module/micronumpy/test/test_base.py
--- a/pypy/module/micronumpy/test/test_base.py
+++ b/pypy/module/micronumpy/test/test_base.py
@@ -1,6 +1,6 @@
 from pypy.conftest import gettestobjspace
 from pypy.module.micronumpy.interp_dtype import W_Float64Dtype
-from pypy.module.micronumpy.interp_numarray import SingleDimArray, FloatWrapper
+from pypy.module.micronumpy.interp_numarray import SingleDimArray, Scalar
 
 
 class BaseNumpyAppTest(object):
@@ -11,9 +11,9 @@
     def test_binop_signature(self, space):
         ar = SingleDimArray(10, dtype=space.fromcache(W_Float64Dtype))
         v1 = ar.descr_add(space, ar)
-        v2 = ar.descr_add(space, FloatWrapper(2.0))
+        v2 = ar.descr_add(space, Scalar(2.0))
         assert v1.signature is not v2.signature
-        v3 = ar.descr_add(space, FloatWrapper(1.0))
+        v3 = ar.descr_add(space, Scalar(1.0))
         assert v2.signature is v3.signature
         v4 = ar.descr_add(space, ar)
         assert v1.signature is v4.signature
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -2,8 +2,8 @@
 from pypy.module.micronumpy.compile import numpy_compile
 from pypy.module.micronumpy.interp_dtype import W_Float64Dtype
 from pypy.module.micronumpy.interp_numarray import (SingleDimArray, Signature,
-    FloatWrapper, Call2, SingleDimSlice, add, mul, Call1)
-from pypy.module.micronumpy.interp_ufuncs import negative
+    Scalar, Call2, SingleDimSlice, Call1)
+from pypy.module.micronumpy.interp_ufuncs import negative, add
 from pypy.rlib.nonconst import NonConstant
 from pypy.rlib.objectmodel import specialize
 from pypy.rpython.test.test_llinterp import interpret
@@ -30,7 +30,7 @@
     def test_add(self):
         def f(i):
             ar = SingleDimArray(i, dtype=self.float64_dtype)
-            v = Call2(add, ar, ar, Signature())
+            v = add(self.float64_dtype, ar, ar)
             concrete = v.get_concrete()
             return concrete.dtype.getitem(concrete.storage, 3)
 
@@ -43,7 +43,7 @@
     def test_floatadd(self):
         def f(i):
             ar = SingleDimArray(i)
-            v = Call2(add, ar, FloatWrapper(4.5), Signature())
+            v = Call2(add, ar, Scalar(4.5), Signature())
             return v.dtype.getitem(v.get_concrete().storage, 3)
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
@@ -164,8 +164,8 @@
     def test_already_forecd(self):
         def f(i):
             ar = SingleDimArray(i)
-            v1 = Call2(add, ar, FloatWrapper(4.5), Signature())
-            v2 = Call2(mul, v1, FloatWrapper(4.5), Signature())
+            v1 = Call2(add, ar, Scalar(4.5), Signature())
+            v2 = Call2(mul, v1, Scalar(4.5), Signature())
             v1.force_if_needed()
             return v2.get_concrete().storage[3]
 
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to