Author: Amaury Forgeot d'Arc <[email protected]>
Branch: decimal-libmpdec
Changeset: r71550:282b135bd0ee
Date: 2014-05-11 19:00 +0200
http://bitbucket.org/pypy/pypy/changeset/282b135bd0ee/

Log:    Add Decimal binary operations, and all comparisons.

diff --git a/pypy/module/_decimal/interp_context.py b/pypy/module/_decimal/interp_context.py
--- a/pypy/module/_decimal/interp_context.py
+++ b/pypy/module/_decimal/interp_context.py
@@ -212,9 +212,9 @@
 
     # Binary arithmetic functions
     def binary_method(self, space, mpd_func, w_x, w_y):
-        from pypy.module._decimal.interp_decimal import W_Decimal
-        w_a, w_b = W_Decimal.convert_binop_raise(space, self, w_x, w_y)
-        w_result = W_Decimal.allocate(space)
+        from pypy.module._decimal import interp_decimal
+        w_a, w_b = interp_decimal.convert_binop_raise(space, self, w_x, w_y)
+        w_result = interp_decimal.W_Decimal.allocate(space)
         with self.catch_status(space) as (ctx, status_ptr):
             mpd_func(w_result.mpd, w_a.mpd, w_b.mpd, ctx, status_ptr)
         return w_result
diff --git a/pypy/module/_decimal/interp_decimal.py b/pypy/module/_decimal/interp_decimal.py
--- a/pypy/module/_decimal/interp_decimal.py
+++ b/pypy/module/_decimal/interp_decimal.py
@@ -161,72 +161,153 @@
         return w_result
 
     def compare(self, space, w_other, op):
-        if not isinstance(w_other, W_Decimal):  # So far
-            return space.w_NotImplemented
-        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1) as status_ptr:
+        context = interp_context.getcontext(space)
+        w_err, w_other = convert_op(space, context, w_other)
+        if w_err:
+            return w_err
+        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1, zero=True) as status_ptr:
             r = rmpdec.mpd_qcmp(self.mpd, w_other.mpd, status_ptr)
+            # zero=True above: status_ptr[0] now holds only flags set by mpd_qcmp()
+            if r > 0xFFFF:  # unordered: at least one operand is a NaN
+                # sNaNs or op={le,ge,lt,gt} always signal.
+                if (rmpdec.mpd_issnan(self.mpd) or rmpdec.mpd_issnan(w_other.mpd)
+                    or (op not in ('eq', 'ne'))):
+                    status = rffi.cast(lltype.Signed, status_ptr[0])
+                    context.addstatus(space, status)
+                # qNaN comparison with op={eq,ne} or comparison with
+                # InvalidOperation disabled.
+                if op == 'ne':
+                    return space.w_True
+                else:
+                    return space.w_False
+
         if op == 'eq':
             return space.wrap(r == 0)
+        elif op == 'ne':
+            return space.wrap(r != 0)
+        elif op == 'le':
+            return space.wrap(r <= 0)
+        elif op == 'ge':
+            return space.wrap(r >= 0)
+        elif op == 'lt':
+            return space.wrap(r == -1)
+        elif op == 'gt':
+            return space.wrap(r == 1)
         else:
             return space.w_NotImplemented
 
     def descr_eq(self, space, w_other):
         return self.compare(space, w_other, 'eq')
+    def descr_ne(self, space, w_other):
+        return self.compare(space, w_other, 'ne')
+    def descr_lt(self, space, w_other):
+        return self.compare(space, w_other, 'lt')
+    def descr_le(self, space, w_other):
+        return self.compare(space, w_other, 'le')
+    def descr_gt(self, space, w_other):
+        return self.compare(space, w_other, 'gt')
+    def descr_ge(self, space, w_other):
+        return self.compare(space, w_other, 'ge')
 
-    # Operations
-    @staticmethod
-    def convert_op(space, context, w_value):
-        if isinstance(w_value, W_Decimal):
-            return None, w_value
-        elif space.isinstance_w(w_value, space.w_int):
-            value = space.bigint_w(w_value)
-            return None, decimal_from_bigint(space, None, value, context,
-                                             exact=True)
-        return space.w_NotImplemented, None
+    # Binary operations: all go through binary_number_method(), which
+    # returns NotImplemented when an operand is neither a Decimal nor an int.
+    def descr_add(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qadd, self, w_other)
+    def descr_sub(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qsub, self, w_other)
+    def descr_mul(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qmul, self, w_other)
+    def descr_truediv(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qdiv, self, w_other)
+    def descr_floordiv(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qdivint, self, w_other)
+    def descr_mod(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qrem, self, w_other)
+    # Reflected variants: w_other is the left-hand operand.
+    def descr_radd(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qadd, w_other, self)
+    def descr_rsub(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qsub, w_other, self)
+    def descr_rmul(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qmul, w_other, self)
+    def descr_rtruediv(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qdiv, w_other, self)
+    def descr_rfloordiv(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qdivint, w_other, self)
+    def descr_rmod(self, space, w_other):
+        return binary_number_method(space, rmpdec.mpd_qrem, w_other, self)
 
     @staticmethod
-    def convert_binop(space, context, w_x, w_y):
-        w_err, w_a = W_Decimal.convert_op(space, context, w_x)
+    def divmod_impl(space, w_x, w_y):
+        context = interp_context.getcontext(space)
+        # Unconvertible operands yield NotImplemented rather than a tuple.
+        w_err, w_a, w_b = convert_binop(space, context, w_x, w_y)
         if w_err:
-            return w_err, None, None
-        w_err, w_b = W_Decimal.convert_op(space, context, w_y)
-        if w_err:
-            return w_err, None, None
-        return None, w_a, w_b
+            return w_err
+        w_q = W_Decimal.allocate(space)
+        w_r = W_Decimal.allocate(space)
+        with context.catch_status(space) as (ctx, status_ptr):
+            rmpdec.mpd_qdivmod(w_q.mpd, w_r.mpd, w_a.mpd, w_b.mpd,
+                               ctx, status_ptr)
+        return space.newtuple([w_q, w_r])
+
+    def descr_divmod(self, space, w_other):
+        return W_Decimal.divmod_impl(space, self, w_other)
+    def descr_rdivmod(self, space, w_other):
+        return W_Decimal.divmod_impl(space, w_other, self)
 
     @staticmethod
-    def convert_binop_raise(space, context, w_x, w_y):
-        w_err, w_a = W_Decimal.convert_op(space, context, w_x)
-        if w_err:
-            raise oefmt(space.w_TypeError,
-                        "conversion from %N to Decimal is not supported",
-                        space.type(w_x))
-        w_err, w_b = W_Decimal.convert_op(space, context, w_y)
-        if w_err:
-            raise oefmt(space.w_TypeError,
-                        "conversion from %N to Decimal is not supported",
-                        space.type(w_y))
-        return w_a, w_b
-
-    def binary_number_method(self, space, mpd_func, w_other):
+    def pow_impl(space, w_base, w_exp, w_mod):
         context = interp_context.getcontext(space)
 
-        w_err, w_a, w_b = W_Decimal.convert_binop(space, context, self, w_other)
+        w_err, w_a, w_b = convert_binop(space, context, w_base, w_exp)
         if w_err:
             return w_err
+        # A modulus other than None selects mpd_qpowmod() over mpd_qpow().
+        if not space.is_none(w_mod):
+            w_err, w_c = convert_op(space, context, w_mod)
+            if w_err:
+                return w_err
+        else:
+            w_c = None
         w_result = W_Decimal.allocate(space)
         with context.catch_status(space) as (ctx, status_ptr):
-            mpd_func(w_result.mpd, w_a.mpd, w_b.mpd, ctx, status_ptr)
+            if w_c:
+                rmpdec.mpd_qpowmod(w_result.mpd, w_a.mpd, w_b.mpd, w_c.mpd,
+                                   ctx, status_ptr)
+            else:
+                rmpdec.mpd_qpow(w_result.mpd, w_a.mpd, w_b.mpd,
+                                ctx, status_ptr)
         return w_result
 
-    def descr_add(self, space, w_other):
-        return self.binary_number_method(space, rmpdec.mpd_qadd, w_other)
-    def descr_sub(self, space, w_other):
-        return self.binary_number_method(space, rmpdec.mpd_qsub, w_other)
-    def descr_mul(self, space, w_other):
-        return self.binary_number_method(space, rmpdec.mpd_qmul, w_other)
-    def descr_truediv(self, space, w_other):
-        return self.binary_number_method(space, rmpdec.mpd_qdiv, w_other)
+    def descr_pow(self, space, w_other, w_mod=None):
+        return W_Decimal.pow_impl(space, self, w_other, w_mod)
+    def descr_rpow(self, space, w_other):
+        return W_Decimal.pow_impl(space, w_other, self, None)
+
+    # Unary operations, evaluated in the current context
+    def unary_number_method(self, space, mpd_func):
+        context = interp_context.getcontext(space)
+        w_result = W_Decimal.allocate(space)
+        with context.catch_status(space) as (ctx, status_ptr):
+            mpd_func(w_result.mpd, self.mpd, ctx, status_ptr)
+        return w_result
+
+    def descr_neg(self, space):
+        return self.unary_number_method(space, rmpdec.mpd_qminus)
+    def descr_pos(self, space):
+        return self.unary_number_method(space, rmpdec.mpd_qplus)
+    def descr_abs(self, space):
+        return self.unary_number_method(space, rmpdec.mpd_qabs)
+    # copy_sign(other, context=None); other must be a Decimal or an int.
+    def copy_sign_w(self, space, w_other, w_context=None):
+        context = convert_context(space, w_context)
+        w_other = convert_op_raise(space, context, w_other)
+        w_result = W_Decimal.allocate(space)
+        with context.catch_status(space) as (ctx, status_ptr):
+            rmpdec.mpd_qcopy_sign(w_result.mpd, self.mpd, w_other.mpd,
+                                  ctx, status_ptr)
+        return w_result
 
     # Boolean functions
     def is_qnan_w(self, space):
@@ -235,6 +316,62 @@
         return space.wrap(bool(rmpdec.mpd_isinfinite(self.mpd)))
 
 
+# Helper functions for arithmetic conversions (Decimal and int operands)
+def convert_op(space, context, w_value):
+    if isinstance(w_value, W_Decimal):
+        return None, w_value
+    elif space.isinstance_w(w_value, space.w_int):
+        value = space.bigint_w(w_value)
+        return None, decimal_from_bigint(space, None, value, context,
+                                         exact=True)
+    return space.w_NotImplemented, None
+
+def convert_op_raise(space, context, w_x):
+    w_err, w_a = convert_op(space, context, w_x)
+    if w_err:
+        raise oefmt(space.w_TypeError,
+                    "conversion from %N to Decimal is not supported",
+                    space.type(w_x))
+    return w_a
+
+def convert_binop(space, context, w_x, w_y):
+    w_err, w_a = convert_op(space, context, w_x)
+    if w_err:
+        return w_err, None, None
+    w_err, w_b = convert_op(space, context, w_y)
+    if w_err:
+        return w_err, None, None
+    return None, w_a, w_b
+
+def convert_binop_raise(space, context, w_x, w_y):
+    w_err, w_a = convert_op(space, context, w_x)
+    if w_err:
+        raise oefmt(space.w_TypeError,
+                    "conversion from %N to Decimal is not supported",
+                    space.type(w_x))
+    w_err, w_b = convert_op(space, context, w_y)
+    if w_err:
+        raise oefmt(space.w_TypeError,
+                    "conversion from %N to Decimal is not supported",
+                    space.type(w_y))
+    return w_a, w_b
+
+def binary_number_method(space, mpd_func, w_x, w_y):
+    context = interp_context.getcontext(space)
+
+    w_err, w_a, w_b = convert_binop(space, context, w_x, w_y)
+    if w_err:
+        return w_err
+    w_result = W_Decimal.allocate(space)
+    with context.catch_status(space) as (ctx, status_ptr):
+        mpd_func(w_result.mpd, w_a.mpd, w_b.mpd, ctx, status_ptr)
+    return w_result
+
+def convert_context(space, w_context):
+    if space.is_none(w_context):  # omitted, or None passed explicitly
+        return interp_context.getcontext(space)
+    return space.interp_w(interp_context.W_Context, w_context)
+
 # Constructors
 def decimal_from_ssize(space, w_subtype, value, context, exact=True):
     w_result = W_Decimal.allocate(space, w_subtype)
@@ -473,13 +610,37 @@
     __floor__ = interp2app(W_Decimal.descr_floor),
     __ceil__ = interp2app(W_Decimal.descr_ceil),
     __round__ = interp2app(W_Decimal.descr_round),
+    # rich comparisons
     __eq__ = interp2app(W_Decimal.descr_eq),
+    __ne__ = interp2app(W_Decimal.descr_ne),
+    __le__ = interp2app(W_Decimal.descr_le),
+    __ge__ = interp2app(W_Decimal.descr_ge),
+    __lt__ = interp2app(W_Decimal.descr_lt),
+    __gt__ = interp2app(W_Decimal.descr_gt),
+    # unary arithmetic
+    __pos__ = interp2app(W_Decimal.descr_pos),
+    __neg__ = interp2app(W_Decimal.descr_neg),
+    __abs__ = interp2app(W_Decimal.descr_abs),
     #
     __add__ = interp2app(W_Decimal.descr_add),
     __sub__ = interp2app(W_Decimal.descr_sub),
     __mul__ = interp2app(W_Decimal.descr_mul),
     __truediv__ = interp2app(W_Decimal.descr_truediv),
+    __floordiv__ = interp2app(W_Decimal.descr_floordiv),
+    __mod__ = interp2app(W_Decimal.descr_mod),
+    __divmod__ = interp2app(W_Decimal.descr_divmod),
+    __pow__ = interp2app(W_Decimal.descr_pow),
     #
+    __radd__ = interp2app(W_Decimal.descr_radd),
+    __rsub__ = interp2app(W_Decimal.descr_rsub),
+    __rmul__ = interp2app(W_Decimal.descr_rmul),
+    __rtruediv__ = interp2app(W_Decimal.descr_rtruediv),
+    __rfloordiv__ = interp2app(W_Decimal.descr_rfloordiv),
+    __rmod__ = interp2app(W_Decimal.descr_rmod),
+    __rdivmod__ = interp2app(W_Decimal.descr_rdivmod),
+    __rpow__ = interp2app(W_Decimal.descr_rpow),
+    # named methods
+    copy_sign = interp2app(W_Decimal.copy_sign_w),
     is_qnan = interp2app(W_Decimal.is_qnan_w),
     is_infinite = interp2app(W_Decimal.is_infinite_w),
     )
diff --git a/pypy/module/_decimal/test/test_context.py b/pypy/module/_decimal/test/test_context.py
--- a/pypy/module/_decimal/test/test_context.py
+++ b/pypy/module/_decimal/test/test_context.py
@@ -1,3 +1,6 @@
+from pypy.interpreter import gateway
+import random
+
 class AppTestContext:
     spaceconfig = dict(usemodules=('_decimal',))
 
@@ -6,6 +9,10 @@
         cls.w_decimal = space.call_function(space.builtin.get('__import__'),
                                             space.wrap("_decimal"))
         cls.w_Decimal = space.getattr(cls.w_decimal, space.wrap("Decimal"))
+        def random_float(space):  # app-level: self.random_float()
+            f = random.expovariate(0.01) * (random.random() * 2.0 - 1.0)
+            return space.wrap(f)
+        cls.w_random_float = space.wrap(gateway.interp2app(random_float))
 
     def test_context_repr(self):
         c = self.decimal.DefaultContext.copy()
@@ -31,3 +38,73 @@
             "flags=[], traps=[])"
         assert s == t
 
+    def test_explicit_context_create_from_float(self):
+        Decimal = self.decimal.Decimal
+
+        nc = self.decimal.Context()
+        r = nc.create_decimal(0.1)
+        assert type(r) is Decimal
+        assert str(r) == '0.1000000000000000055511151231'
+        assert nc.create_decimal(float('nan')).is_qnan()
+        assert nc.create_decimal(float('inf')).is_infinite()
+        assert nc.create_decimal(float('-inf')).is_infinite()
+        assert (str(nc.create_decimal(float('nan'))) ==
+                str(nc.create_decimal('NaN')))
+        assert (str(nc.create_decimal(float('inf'))) ==
+                str(nc.create_decimal('Infinity')))
+        assert (str(nc.create_decimal(float('-inf'))) ==
+                str(nc.create_decimal('-Infinity')))
+        assert (str(nc.create_decimal(float('-0.0'))) ==
+                str(nc.create_decimal('-0')))
+        nc.prec = 100
+        for i in range(200):
+            x = self.random_float()
+            assert x == float(nc.create_decimal(x))  # roundtrip
+
+    def test_add(self):
+        Decimal = self.decimal.Decimal
+        Context = self.decimal.Context
+
+        c = Context()
+        d = c.add(Decimal(1), Decimal(1))
+        assert c.add(1, 1) == d
+        assert c.add(Decimal(1), 1) == d
+        assert c.add(1, Decimal(1)) == d
+        raises(TypeError, c.add, '1', 1)
+        raises(TypeError, c.add, 1, '1')
+
+    def test_subtract(self):
+        Decimal = self.decimal.Decimal
+        Context = self.decimal.Context
+
+        c = Context()
+        d = c.subtract(Decimal(1), Decimal(2))
+        assert c.subtract(1, 2) == d
+        assert c.subtract(Decimal(1), 2) == d
+        assert c.subtract(1, Decimal(2)) == d
+        raises(TypeError, c.subtract, '1', 2)
+        raises(TypeError, c.subtract, 1, '2')
+
+    def test_multiply(self):
+        Decimal = self.decimal.Decimal
+        Context = self.decimal.Context
+
+        c = Context()
+        d = c.multiply(Decimal(1), Decimal(2))
+        assert c.multiply(1, 2) == d
+        assert c.multiply(Decimal(1), 2) == d
+        assert c.multiply(1, Decimal(2)) == d
+        raises(TypeError, c.multiply, '1', 2)
+        raises(TypeError, c.multiply, 1, '2')
+
+    def test_divide(self):
+        Decimal = self.decimal.Decimal
+        Context = self.decimal.Context
+
+        c = Context()
+        d = c.divide(Decimal(1), Decimal(2))
+        assert c.divide(1, 2) == d
+        assert c.divide(Decimal(1), 2) == d
+        assert c.divide(1, Decimal(2)) == d
+        raises(TypeError, c.divide, '1', 2)
+        raises(TypeError, c.divide, 1, '2')
diff --git a/pypy/module/_decimal/test/test_decimal.py b/pypy/module/_decimal/test/test_decimal.py
--- a/pypy/module/_decimal/test/test_decimal.py
+++ b/pypy/module/_decimal/test/test_decimal.py
@@ -293,29 +293,6 @@
         assert str(nc.create_decimal(Decimal('NaN12345'))) == 'NaN'
         assert nc.flags[InvalidOperation]
 
-    def test_explicit_context_create_from_float(self):
-        Decimal = self.decimal.Decimal
-
-        nc = self.decimal.Context()
-        r = nc.create_decimal(0.1)
-        assert type(r) is Decimal
-        assert str(r) == '0.1000000000000000055511151231'
-        assert nc.create_decimal(float('nan')).is_qnan()
-        assert nc.create_decimal(float('inf')).is_infinite()
-        assert nc.create_decimal(float('-inf')).is_infinite()
-        assert (str(nc.create_decimal(float('nan'))) ==
-                str(nc.create_decimal('NaN')))
-        assert (str(nc.create_decimal(float('inf'))) ==
-                str(nc.create_decimal('Infinity')))
-        assert (str(nc.create_decimal(float('-inf'))) ==
-                str(nc.create_decimal('-Infinity')))
-        assert (str(nc.create_decimal(float('-0.0'))) ==
-                str(nc.create_decimal('-0')))
-        nc.prec = 100
-        for i in range(200):
-            x = self.random_float()
-            assert x == float(nc.create_decimal(x))  # roundtrip
-
     def test_operations(self):
         Decimal = self.decimal.Decimal
 
@@ -437,50 +414,301 @@
         for d, n, r in test_triples:
             assert str(round(Decimal(d), n)) == r
 
-    def test_add(self):
+    def test_addition(self):
         Decimal = self.decimal.Decimal
-        Context = self.decimal.Context
 
-        c = Context()
-        d = c.add(Decimal(1), Decimal(1))
-        assert c.add(1, 1) == d
-        assert c.add(Decimal(1), 1) == d
-        assert c.add(1, Decimal(1)) == d
-        raises(TypeError, c.add, '1', 1)
-        raises(TypeError, c.add, 1, '1')
+        d1 = Decimal('-11.1')
+        d2 = Decimal('22.2')
 
-    def test_subtract(self):
+        # two Decimals
+        assert d1+d2 == Decimal('11.1')
+        assert d2+d1 == Decimal('11.1')
+
+        # with other type, left
+        c = d1 + 5
+        assert c == Decimal('-6.1')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 5 + d1
+        assert c == Decimal('-6.1')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 += d2
+        assert d1 == Decimal('11.1')
+
+        # inline with other type
+        d1 += 5
+        assert d1 == Decimal('16.1')
+
+    def test_subtraction(self):
         Decimal = self.decimal.Decimal
-        Context = self.decimal.Context
 
-        c = Context()
-        d = c.subtract(Decimal(1), Decimal(2))
-        assert c.subtract(1, 2) == d
-        assert c.subtract(Decimal(1), 2) == d
-        assert c.subtract(1, Decimal(2)) == d
-        raises(TypeError, c.subtract, '1', 2)
-        raises(TypeError, c.subtract, 1, '2')
+        d1 = Decimal('-11.1')
+        d2 = Decimal('22.2')
 
-    def test_multiply(self):
+        # two Decimals
+        assert d1-d2 == Decimal('-33.3')
+        assert d2-d1 == Decimal('33.3')
+
+        # with other type, left
+        c = d1 - 5
+        assert c == Decimal('-16.1')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 5 - d1
+        assert c == Decimal('16.1')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 -= d2
+        assert d1 == Decimal('-33.3')
+
+        # inline with other type
+        d1 -= 5
+        assert d1 == Decimal('-38.3')
+
+    def test_multiplication(self):
         Decimal = self.decimal.Decimal
-        Context = self.decimal.Context
 
-        c = Context()
-        d = c.multiply(Decimal(1), Decimal(2))
-        assert c.multiply(1, 2)== d
-        assert c.multiply(Decimal(1), 2)== d
-        assert c.multiply(1, Decimal(2))== d
-        raises(TypeError, c.multiply, '1', 2)
-        raises(TypeError, c.multiply, 1, '2')
+        d1 = Decimal('-5')
+        d2 = Decimal('3')
 
-    def test_divide(self):
+        # two Decimals
+        assert d1*d2 == Decimal('-15')
+        assert d2*d1 == Decimal('-15')
+
+        # with other type, left
+        c = d1 * 5
+        assert c == Decimal('-25')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 5 * d1
+        assert c == Decimal('-25')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 *= d2
+        assert d1 == Decimal('-15')
+
+        # inline with other type
+        d1 *= 5
+        assert d1 == Decimal('-75')
+
+    def test_division(self):
         Decimal = self.decimal.Decimal
-        Context = self.decimal.Context
 
-        c = Context()
-        d = c.divide(Decimal(1), Decimal(2))
-        assert c.divide(1, 2)== d
-        assert c.divide(Decimal(1), 2)== d
-        assert c.divide(1, Decimal(2))== d
-        raises(TypeError, c.divide, '1', 2)
-        raises(TypeError, c.divide, 1, '2')
+        d1 = Decimal('-5')
+        d2 = Decimal('2')
+
+        # two Decimals
+        assert d1/d2 == Decimal('-2.5')
+        assert d2/d1 == Decimal('-0.4')
+
+        # with other type, left
+        c = d1 / 4
+        assert c == Decimal('-1.25')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 4 / d1
+        assert c == Decimal('-0.8')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 /= d2
+        assert d1 == Decimal('-2.5')
+
+        # inline with other type
+        d1 /= 4
+        assert d1 == Decimal('-0.625')
+
+    def test_floor_division(self):
+        Decimal = self.decimal.Decimal
+
+        d1 = Decimal('5')
+        d2 = Decimal('2')
+
+        # two Decimals
+        assert d1//d2 == Decimal('2')
+        assert d2//d1 == Decimal('0')
+
+        # with other type, left
+        c = d1 // 4
+        assert c == Decimal('1')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 7 // d1
+        assert c == Decimal('1')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 //= d2
+        assert d1 == Decimal('2')
+
+        # inline with other type
+        d1 //= 2
+        assert d1 == Decimal('1')
+
+    def test_powering(self):
+        Decimal = self.decimal.Decimal
+
+        d1 = Decimal('5')
+        d2 = Decimal('2')
+
+        # two Decimals
+        assert d1**d2 == Decimal('25')
+        assert d2**d1 == Decimal('32')
+
+        # with other type, left
+        c = d1 ** 4
+        assert c == Decimal('625')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 7 ** d1
+        assert c == Decimal('16807')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 **= d2
+        assert d1 == Decimal('25')
+
+        # inline with other type
+        d1 **= 4
+        assert d1 == Decimal('390625')
+
+    def test_module(self):
+        Decimal = self.decimal.Decimal
+
+        d1 = Decimal('5')
+        d2 = Decimal('2')
+
+        # two Decimals
+        assert d1%d2 == Decimal('1')
+        assert d2%d1 == Decimal('2')
+
+        # with other type, left
+        c = d1 % 4
+        assert c == Decimal('1')
+        assert type(c) == type(d1)
+
+        # with other type, right
+        c = 7 % d1
+        assert c == Decimal('2')
+        assert type(c) == type(d1)
+
+        # inline with decimal
+        d1 %= d2
+        assert d1 == Decimal('1')
+
+        # inline with other type
+        d1 %= 4
+        assert d1 == Decimal('1')
+
+    def test_floor_div_module(self):
+        Decimal = self.decimal.Decimal
+
+        d1 = Decimal('5')
+        d2 = Decimal('2')
+
+        # two Decimals
+        (p, q) = divmod(d1, d2)
+        assert p == Decimal('2')
+        assert q == Decimal('1')
+        assert type(p) == type(d1)
+        assert type(q) == type(d1)
+
+        # with other type, left
+        (p, q) = divmod(d1, 4)
+        assert p == Decimal('1')
+        assert q == Decimal('1')
+        assert type(p) == type(d1)
+        assert type(q) == type(d1)
+
+        # with other type, right
+        (p, q) = divmod(7, d1)
+        assert p == Decimal('1')
+        assert q == Decimal('2')
+        assert type(p) == type(d1)
+        assert type(q) == type(d1)
+
+    def test_unary_operators(self):
+        Decimal = self.decimal.Decimal
+        # __pos__, __neg__ and __abs__
+        assert +Decimal(45) == Decimal(+45)
+        assert -Decimal(45) == Decimal(-45)
+        assert abs(Decimal(45)) == abs(Decimal(-45))
+
+    def test_nan_comparisons(self):
+        import operator
+        # comparisons involving signaling nans signal InvalidOperation
+
+        # order comparisons (<, <=, >, >=) involving only quiet nans
+        # also signal InvalidOperation
+
+        # equality comparisons (==, !=) involving only quiet nans
+        # don't signal, but return False or True respectively.
+        Decimal = self.decimal.Decimal
+        InvalidOperation = self.decimal.InvalidOperation
+        Overflow = self.decimal.Overflow
+        DivisionByZero = self.decimal.DivisionByZero
+        localcontext = self.decimal.localcontext
+
+        n = Decimal('NaN')
+        s = Decimal('sNaN')
+        i = Decimal('Inf')
+        f = Decimal('2')
+
+        qnan_pairs = (n, n), (n, i), (i, n), (n, f), (f, n)
+        snan_pairs = (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)
+        order_ops = operator.lt, operator.le, operator.gt, operator.ge
+        equality_ops = operator.eq, operator.ne
+
+        # results when InvalidOperation is not trapped (local context only)
+        with localcontext() as ctx:
+            ctx.traps[InvalidOperation] = False
+            ctx.traps[Overflow] = False
+            ctx.traps[DivisionByZero] = False
+            for x, y in qnan_pairs + snan_pairs:
+                for op in order_ops + equality_ops:
+                    got = op(x, y)
+                    expected = True if op is operator.ne else False
+                    assert expected is got, (
+                        "expected {0!r} for operator.{1}({2!r}, {3!r}); "
+                        "got {4!r}".format(
+                            expected, op.__name__, x, y, got))
+
+        # repeat the above, but this time trap the InvalidOperation
+        with localcontext() as ctx:
+            ctx.traps[InvalidOperation] = 1
+            ctx.traps[Overflow] = False
+            ctx.traps[DivisionByZero] = False
+            for x, y in qnan_pairs:
+                for op in equality_ops:
+                    got = op(x, y)
+                    expected = True if op is operator.ne else False
+                    assert expected is got, (
+                        "expected {0!r} for "
+                        "operator.{1}({2!r}, {3!r}); "
+                        "got {4!r}".format(
+                            expected, op.__name__, x, y, got))
+
+            for x, y in snan_pairs:
+                for op in equality_ops:
+                    raises(InvalidOperation, op, x, y)
+
+            for x, y in qnan_pairs + snan_pairs:
+                for op in order_ops:
+                    raises(InvalidOperation, op, x, y)
+
+    def test_copy_sign(self):
+        Decimal = self.decimal.Decimal
+        d = Decimal(1).copy_sign(Decimal(-2))
+        assert str(d) == '-1'
+        assert Decimal(1).copy_sign(-2) == d
+        raises(TypeError, Decimal(1).copy_sign, '-2')
diff --git a/rpython/rlib/rmpdec.py b/rpython/rlib/rmpdec.py
--- a/rpython/rlib/rmpdec.py
+++ b/rpython/rlib/rmpdec.py
@@ -47,7 +47,10 @@
         "mpd_iszero", "mpd_isnegative", "mpd_isinfinite", "mpd_isspecial",
         "mpd_isnan", "mpd_issnan", "mpd_isqnan",
         "mpd_qcmp", "mpd_qquantize",
-        "mpd_qpow", "mpd_qadd", "mpd_qsub", "mpd_qmul", "mpd_qdiv",
+        "mpd_qplus", "mpd_qminus", "mpd_qabs",  # unary operations
+        "mpd_qadd", "mpd_qsub", "mpd_qmul", "mpd_qdiv", "mpd_qdivint",
+        "mpd_qrem", "mpd_qdivmod", "mpd_qpow", "mpd_qpowmod", 
+        "mpd_qcopy_sign",
         "mpd_qround_to_int",
         ],
     compile_extra=compile_extra,
@@ -221,9 +224,17 @@
     'mpd_qquantize', [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
     lltype.Void)
 
-mpd_qpow = external(
-    'mpd_qpow',
-    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+mpd_qplus = external(  # unary: (result, operand, ctx, status)
+    'mpd_qplus',
+    [MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qminus = external(  # unary: (result, operand, ctx, status)
+    'mpd_qminus',
+    [MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qabs = external(  # unary: (result, operand, ctx, status)
+    'mpd_qabs',
+    [MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
     lltype.Void)
 mpd_qadd = external(
     'mpd_qadd',
@@ -241,6 +252,36 @@
     'mpd_qdiv',
     [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
     lltype.Void)
+mpd_qdivint = external(
+    'mpd_qdivint',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qrem = external(
+    'mpd_qrem',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qdivmod = external(
+    'mpd_qdivmod',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qpow = external(
+    'mpd_qpow',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qpowmod = external(
+    'mpd_qpowmod',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+# libmpdec declares mpd_qcopy_sign(result, a, b, status): no context argument.
+_mpd_qcopy_sign = external(
+    'mpd_qcopy_sign',
+    [MPD_PTR, MPD_PTR, MPD_PTR, rffi.UINTP],
+    lltype.Void)
+
+def mpd_qcopy_sign(result, a, b, ctx, status):
+    # ctx is accepted and ignored so callers can pass the same arguments
+    # as for the other binary mpd_q*() functions declared above.
+    _mpd_qcopy_sign(result, a, b, status)
 
 mpd_qround_to_int = external(
     'mpd_qround_to_int', [MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to