details:   http://hg.sympy.org/sympy/rev/66c846fa9476
changeset: 1826:66c846fa9476
user:      Ondrej Certik <[EMAIL PROTECTED]>
date:      Sun Oct 19 20:10:54 2008 +0200
description:
merge

diffs (405 lines):

diff -r 18e3ec461fa2 -r 66c846fa9476 sympy/core/evalf.py
--- a/sympy/core/evalf.py       Sun Oct 19 16:15:19 2008 +0200
+++ b/sympy/core/evalf.py       Sun Oct 19 20:10:54 2008 +0200
@@ -578,6 +578,25 @@
         raise NotImplementedError
     return mpf_atan(xre, prec, round_nearest), None, prec, None
 
+def evalf_piecewise(expr, prec, options):
+    if 'subs' in options:
+        # Work on a copy: the caller's options dict must keep its 'subs'
+        # entry.  Dropping 'subs' for the recursive call means a result
+        # that is still an unevaluated Piecewise falls through to
+        # NotImplementedError instead of recursing forever.
+        options = options.copy()
+        expr = expr.subs(options.pop('subs'))
+        if hasattr(expr,'func'):
+            return evalf(expr, prec, options)
+        if isinstance(expr, float):
+            return evalf(C.Real(expr), prec, options)
+        if isinstance(expr, int):
+            return evalf(C.Integer(expr), prec, options)
+
+    # Reached when no substitution was given, or when it produced an
+    # object without .func that is neither a float nor an int.
+    raise NotImplementedError
+
 
 #----------------------------------------------------------------------------#
 #                                                                            #
@@ -904,6 +923,7 @@
 
     C.Integral : evalf_integral,
     C.Sum : evalf_sum,
+    C.Piecewise : evalf_piecewise,
     }
 
 def evalf(x, prec, options):
diff -r 18e3ec461fa2 -r 66c846fa9476 sympy/functions/elementary/piecewise.py
--- a/sympy/functions/elementary/piecewise.py   Sun Oct 19 16:15:19 2008 +0200
+++ b/sympy/functions/elementary/piecewise.py   Sun Oct 19 20:10:54 2008 +0200
@@ -1,7 +1,9 @@
 
-from sympy.core.basic import Basic
-from sympy.core.function import Function, diff
-
+from sympy.core.basic import Basic, S
+from sympy.core.function import Function, FunctionClass, diff
+from sympy.core.numbers import Number
+from sympy.core.relational import Relational
+from sympy.core.sympify import sympify
 
 class Piecewise(Function):
     """
@@ -9,10 +11,14 @@
 
     Usage
     =====
-      Piecewise(x, (-1, 0, f(x)), (0, oo, g(x))) -> Returns piecewise function
-        - The first argument is the variable of the intervals.
-        - The subsequent arguments are tuples defining each piece
-          (begin, end, function)
+      Piecewise( (expr,cond), (expr,cond), ... )
+        - Each argument is a 2-tuple defining an expression and condition.
+        - The conds are evaluated in turn, returning the expr of the first
+          one that is True. If a cond before it cannot be decided (e.g.
+          x < 1 with symbolic x), the function is returned in symbolic form.
+        - If the function is evaluated at a place where all conditions are False,
+          a ValueError exception will be raised.
+        - Pairs whose cond is explicitly False are removed.
 
     Examples
     ========
@@ -20,35 +26,213 @@
       >>> x = Symbol('x')
       >>> f = x**2
       >>> g = log(x)
-      >>> p = Piecewise(x, (-1,0,f), (0,oo,g))
-      >>> p.diff(x)
-      Piecewise(x, (-1, 0, 2*x), (0, oo, 1/x))
-      >>> f*p
-      x**2*Piecewise(x, (-1, 0, x**2), (0, oo, log(x)))
+      >>> p = Piecewise( (0, x<-1), (f, x<=1), (g, True))
+      >>> p.subs(x,1)
+      1
+      >>> p.subs(x,5)
+      log(5)
     """
 
-    nargs=1
+    nargs=None
+
+    def __new__(cls, *args, **options):
+        # Drop these option keys before they reach Basic.__new__; if
+        # canonize() can simplify, its result is returned instead.
+        for opt in ["nargs", "dummy", "comparable", "noncommutative", "commutative"]:
+            if opt in options:
+                del options[opt]
+        r = cls.canonize(*args)
+        if r is None:
+            return Basic.__new__(cls, *args, **options)
+        else:
+            return r
 
     @classmethod
     def canonize(cls, *args):
-        if not args[0].is_Symbol:
-            raise TypeError, "First argument must be symbol"
-        for piece in args[1:]:
-            if not isinstance(piece,tuple) or len(piece) != 3:
-                raise TypeError, "Must use 3-tuples for intervals"
+        """Validate the (expr, cond) pairs and simplify where possible.
+
+        Returns the expr of the first True cond if all earlier conds are
+        False, a Piecewise without the False pairs, or None (unevaluated).
+        Raises TypeError on malformed pairs, ValueError if all are False.
+        """
+        # Check types first
+        for ec in args:
+            if (not isinstance(ec,tuple)) or len(ec)!=2:
+                raise TypeError, "args may only include (expr, cond) pairs"
+        for expr, cond in args:
+            cond_type = type(cond)
+            if cond_type != bool and not issubclass(cond_type,Relational):
+                raise TypeError, \
+                    "Cond %s is of type %s, but must be a bool or Relational" \
+                    % (cond, cond_type)
+
+        # Check for situations where we can evaluate the Piecewise object.
+        # 1) Hit an unevaluatable cond (e.g. x<1) -> keep object
+        # 2) Hit a true condition -> return that expr
+        # 3) Remove false conditions, if no conditions left -> raise ValueError
+        all_conds_evaled = True
+        non_false_ecpairs = []
+        for expr, cond in args:
+            cond_eval = cls.__eval_cond(cond)
+            if cond_eval is None:
+                all_conds_evaled = False
+                non_false_ecpairs.append( (expr, cond) )
+            elif cond_eval:
+                if all_conds_evaled:
+                    return expr
+                non_false_ecpairs.append( (expr, cond) )
+        if args and not non_false_ecpairs:
+            raise ValueError, "Piecewise: every condition evaluated to False"
+        if len(non_false_ecpairs) != len(args):
+            return Piecewise(*non_false_ecpairs)
+
         return None
 
+    def doit(self, **hints):
+        """Apply doit(**hints) to every expr and cond that supports it."""
+        new_ecpairs = []
+        for expr, cond in self.args:
+            if hasattr(expr,'doit'):
+                new_expr = expr.doit(**hints)
+            else:
+                new_expr = expr
+            if hasattr(cond,'doit'):
+                new_cond = cond.doit(**hints)
+            else:
+                new_cond = cond
+            new_ecpairs.append( (new_expr, new_cond) )
+        return Piecewise(*new_ecpairs)
+
+    def _eval_integral(self,x):
+        """Integrate each expr with respect to x; conds are unchanged."""
+        from sympy.integrals import integrate
+        return  Piecewise(*[(integrate(expr, x),cond) \
+                                for expr, cond in self.args])
+
+    def _eval_interval(self, sym, ab):
+        """Sum expr(hi) - expr(lo) over each piece's part of the interval ab."""
+        # FIXME: Currently only supports conds of type sym < Num, or Num < sym
+        from sympy.series import limit
+        int_expr = []
+        a, b = ab
+        mul = 1
+        if a > b:
+            a, b, mul = b, a, -1
+        default = None
+
+        # Determine what intervals the expr,cond pairs affect.
+        # 1) If cond is True, then log it as default
+        # 1.1) Currently if cond can't be evaluated, throw NotImplementedError.
+        # 2) For each inequality, if previous cond defines part of the interval
+        #    update the new conds interval.
+        #    -  eg x < 1, x < 3 -> [-oo,1],[1,3] instead of [-oo,1],[-oo,3]
+        # 3) Sort the intervals to make it easier to find correct exprs
+        for expr, cond in self.args:
+            if type(cond) == bool:
+                if cond:
+                    default = expr
+                    break
+                else:
+                    continue
+            curr = list(cond.args)
+            if cond.args[0] == sym:
+                curr[0] = S.NegativeInfinity
+            elif cond.args[1] == sym:
+                curr[1] = S.Infinity
+            else:
+                raise NotImplementedError, \
+                    "Currently only supporting evaluation with only "\
+                    "sym on one side of the relation."
+            curr = [max(a,curr[0]),min(b,curr[1])]
+            for int_a, int_b, _ in int_expr:
+                if self.__eval_cond(curr[0] < int_b) and \
+                        self.__eval_cond(curr[0] >= int_a):
+                    curr[0] = int_b
+                if self.__eval_cond(curr[1] > int_a) and \
+                        self.__eval_cond(curr[1] <= int_b):
+                    curr[1] = int_a
+            if self.__eval_cond(curr[0] < curr[1]):
+                int_expr.append(curr+[expr])
+        int_expr.sort(key=lambda interval: interval[0])
+
+        # Collect holes (gaps before each interval and after the last one,
+        # up to b): fill them with the default, or raise a ValueError.
+        holes = []
+        curr_low = a
+        for int_a, int_b, expr in int_expr:
+            if curr_low < int_a:
+                holes.append([curr_low, min(b,int_a), default])
+            curr_low = int_b
+            if curr_low > b:
+                break
+        if curr_low < b:
+            holes.append([curr_low, b, default])
+        if holes and default is not None:
+            int_expr.extend(holes)
+        elif holes and default is None:
+            raise ValueError, "Called interval evaluation over piecewise "\
+                              "function on undefined intervals %s" %\
+                              ", ".join([str((h[0],h[1])) for h in holes])
+
+        # Finally run through the intervals and sum the evaluation.
+        # TODO: Either refactor this code or Integral.doit to call _eval_interval
+        ret_fun = 0
+        for int_a, int_b, expr in int_expr:
+            B = expr.subs(sym, min(b,int_b))
+            if B is S.NaN:
+                B = limit(expr, sym, min(b,int_b))
+            if B is S.NaN:
+                return self
+            A = expr.subs(sym, max(a,int_a))
+            if A is S.NaN:
+                A = limit(expr, sym, max(a,int_a))
+            if A is S.NaN:
+                return self
+            ret_fun += B - A
+        return mul * ret_fun
+
     def _eval_derivative(self, s):
-        new_pieces = []
-        for start, end, f in self.args[1:]:
-            t = (start, end, diff(f, s))
-            new_pieces.append( t )
-        return Piecewise(self.args[0], *new_pieces)
+        """Differentiate each expr with respect to s; conds are unchanged."""
+        return Piecewise(*[(diff(expr,s),cond) for expr, cond in self.args])
 
     def _eval_subs(self, old, new):
+        """Substitute old -> new in each expr and cond, then rebuild."""
         if self == old:
             return new
-        new_pieces = []
-        for start, end, f in self.args[1:]:
-            new_pieces.append( (start, end, f._eval_subs(old,new)) )
-        return Piecewise(self.args[0], *new_pieces)
+        new_ecpairs = []
+        for expr, cond in self.args:
+            if hasattr(expr,"subs") and not isinstance(expr,FunctionClass):
+                new_expr = expr.subs(old, new)
+            else:
+                new_expr = expr
+            if hasattr(cond,"subs"):
+                new_cond = cond.subs(old, new)
+            else:
+                new_cond = cond
+            new_ecpairs.append( (new_expr, new_cond) )
+        return Piecewise(*new_ecpairs)
+
+    @classmethod
+    def __eval_cond(cls, cond):
+        """
+        Return whether the condition is True or False.
+
+        If it cannot be decided (e.g. a side is still symbolic), return None.
+        """
+        if type(cond) == bool:
+            return cond
+        arg0 = cond.args[0]
+        arg1 = cond.args[1]
+        if isinstance(arg0, FunctionClass) or isinstance(arg1, FunctionClass):
+            return None
+        if hasattr(arg0,'evalf'):
+            arg0 = arg0.evalf()
+        if not issubclass(type(arg0),Number) and \
+                type(arg0) != int and type(arg0) != float:
+            return None
+        if hasattr(arg1,'evalf'):
+            arg1 = arg1.evalf()
+        if not issubclass(type(arg1),Number) and \
+                type(arg1) != int and type(arg1) != float:
+            return None
+        return bool(cond)
diff -r 18e3ec461fa2 -r 66c846fa9476 sympy/functions/elementary/tests/test_piecewise.py
--- a/sympy/functions/elementary/tests/test_piecewise.py        Sun Oct 19 16:15:19 2008 +0200
+++ b/sympy/functions/elementary/tests/test_piecewise.py        Sun Oct 19 20:10:54 2008 +0200
@@ -1,13 +1,67 @@
-from sympy import oo, diff, log, Symbol, Piecewise
+from sympy import diff, Integral, integrate, log, oo, Piecewise, raises, symbols
 
-x = Symbol('x')
+x,y = symbols('xy')
 
 def test_piecewise():
-    assert Piecewise(x, (0,1,x)) == Piecewise(x, (0,1,x))
 
-    p = Piecewise(x, (-oo, -1, -1), (-1, 0, x**2), (0, oo, log(x)))
-    dp = Piecewise(x, (-oo, -1, 0), (-1, 0, 2*x), (0, oo, 1/x))
+    # Test canonization
+    assert Piecewise((x, x < 1), (0, True)) == Piecewise((x, x < 1), (0, True))
+    assert Piecewise((x, x < 1), (0, False), (-1, 1>2)) == Piecewise((x, x < 1))
+    assert Piecewise((x, True)) == x
+    raises(TypeError,"Piecewise(x)")
+    raises(TypeError,"Piecewise((x,x**2))")
+
+    # Test subs
+    p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), x >=0))
+    p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), x**2 >=0))
+    assert p.subs(x,x**2) == p_x2
+    assert p.subs(x,-5) == -1
+    assert p.subs(x,-1) == 1
+    assert p.subs(x,1) == log(1)
+
+    # Test evalf
+    assert p.evalf() == p
+    assert p.evalf(subs={x:-2}) == -1
+    assert p.evalf(subs={x:-1}) == 1
+    assert p.evalf(subs={x:1}) == log(1)
+
+    # Test doit
+    f_int = Piecewise((Integral(x,(x,0,1)), x < 1))
+    assert f_int.doit() == Piecewise( (1.0/2.0, x < 1) )
+
+    # Test differentiation
+    f = x
+    fp = x*p
+    dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, x >= 0))
+    fp_dx = x*dp + p
     assert diff(p,x) == dp
+    # FIXME: Seems that the function derivatives are flipping the args.
+    # assert diff(f*p,x) == fp_dx
+    # Test args for now.
+    assert fp_dx.args[0] == diff(f*p,x).args[1]
+    assert fp_dx.args[1] == diff(f*p,x).args[0]
 
-    p_x2 = Piecewise(x, (-oo, -1, -1), (-1, 0, x**4), (0, oo, log(x**2)))
-    assert p.subs(x,x**2) == p_x2
+    # Test simple arithmetic
+    assert x*p == fp
+    assert x*p + p == p + x*p
+    assert p + f == f + p
+    assert p + dp == dp + p
+    assert p - dp == -(dp - p)
+
+    # Test _eval_interval
+    f1 = x*y + 2
+    f2 = x*y**2 + 3
+    peval = Piecewise( (f1, x<0), (f2, x>0))
+    peval_interval = f1.subs(x,0) - f1.subs(x,-1) + f2.subs(x,1) - f2.subs(x,0)
+    assert peval._eval_interval(x,(-1,1)) == peval_interval
+
+    # Test integration
+    p_int =  Piecewise((-x,x < -1), (x**3/3.0, x < 0), (-x + x*log(x), x >= 0))
+    assert integrate(p,x) == p_int
+    p = Piecewise((x, x < 1),(x**2, -1 <= x),(x,3<x))
+    assert integrate(p,(x,-2,2)) == 5.0/6.0
+    assert integrate(p,(x,2,-2)) == -5.0/6.0
+    p = Piecewise((0, x < 0), (1,x < 1), (0, x < 2), (1, x < 3), (0, True))
+    assert integrate(p, (x,-oo,oo)) == 2
+    p = Piecewise((x, x < -10),(x**2, x <= -1),(x, 1 < x))
+    raises(ValueError, "integrate(p,(x,-2,2))")
diff -r 18e3ec461fa2 -r 66c846fa9476 sympy/integrals/integrals.py
--- a/sympy/integrals/integrals.py      Sun Oct 19 16:15:19 2008 +0200
+++ b/sympy/integrals/integrals.py      Sun Oct 19 20:10:54 2008 +0200
@@ -9,7 +9,7 @@
 from sympy.series import limit
 from sympy.polys import Poly
 from sympy.solvers import solve
-from sympy.functions import DiracDelta, Heaviside
+from sympy.functions import DiracDelta, Heaviside, Piecewise
 
 class Integral(Basic):
     """Represents unevaluated integral."""
@@ -132,6 +132,10 @@
                 if ab is None:
                     function = antideriv
                 else:
+                    if isinstance(antideriv,Piecewise):  # per-interval evaluation
+                        function = antideriv._eval_interval(x,ab)
+                        continue  # skip the generic endpoint substitution below
+
                     a,b = ab
                     A = antideriv.subs(x, a)
 
@@ -241,6 +245,10 @@
          # see Polynomial for details.
          if isinstance(f, Poly):
              return f.integrate(x)
+
+        # Piecewise antiderivatives need to call special integrate.
+        if isinstance(f,Piecewise):
+            return f._eval_integral(x)
 
          # let's cut it short if `f` does not depend on `x`
          if not f.has(x):

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-commits" group.
To post to this group, send email to sympy-commits@googlegroups.com
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at 
http://groups.google.com/group/sympy-commits?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to