details:   http://hg.sympy.org/sympy/rev/d040911a2af5
changeset: 1821:d040911a2af5
user:      Andy R. Terrel <[EMAIL PROTECTED]>
date:      Fri Oct 17 08:32:44 2008 +0200
description:
This patch changed the function call for Piecewise(sym,(a,b,expr), ...) to a 
more Mathematica style Piecewise((expr, cond), ... ).  It adds support for 
current functionality with this new interface.

diffs (224 lines):

diff -r 237526ad1977 -r d040911a2af5 sympy/functions/elementary/piecewise.py
--- a/sympy/functions/elementary/piecewise.py   Thu Oct 16 18:53:28 2008 +0200
+++ b/sympy/functions/elementary/piecewise.py   Fri Oct 17 08:32:44 2008 +0200
@@ -1,7 +1,10 @@
 
-from sympy.core.basic import Basic
-from sympy.core.function import Function, diff
-
+from sympy.core.basic import Basic, S
+from sympy.core.cache import cacheit
+from sympy.core.function import Function, FunctionClass, diff
+from sympy.core.numbers import Number
+from sympy.core.relational import Relational
+from sympy.core.sympify import sympify
 
 class Piecewise(Function):
     """
@@ -9,10 +12,14 @@
 
     Usage
     =====
-      Piecewise(x, (-1, 0, f(x)), (0, oo, g(x))) -> Returns piecewise function
-        - The first argument is the variable of the intervals.
-        - The subsequent arguments are tuples defining each piece
-          (begin, end, function)
+      Piecewise( (expr,cond), (expr,cond), ... )
+        - Each argument is a 2-tuple defining a expression and condition
+        - The conds are evaluated in turn returning the first that is True.
+          If any of the evaluated conds are not determined explicitly False,
+          e.g. x < 1, the function is returned in symbolic form.
+        - If the function is evaluated at a place where all conditions are 
False,
+          a ValueError exception will be raised.
+        - Pairs where the cond is explicitly False, will be removed.
 
     Examples
     ========
@@ -20,35 +27,105 @@
       >>> x = Symbol('x')
       >>> f = x**2
       >>> g = log(x)
-      >>> p = Piecewise(x, (-1,0,f), (0,oo,g))
-      >>> p.diff(x)
-      Piecewise(x, (-1, 0, 2*x), (0, oo, 1/x))
-      >>> f*p
-      x**2*Piecewise(x, (-1, 0, x**2), (0, oo, log(x)))
+      >>> p = Piecewise( (0, x<-1), (f, x<=1), (g, True))
+      >>> p.subs(x,1)
+      1
+      >>> p.subs(x,5)
+      log(5)
     """
 
-    nargs=1
+    nargs=None
+
+    @cacheit
+    def __new__(cls, *args, **options):
+        for opt in ["nargs", "dummy", "comparable", "noncommutative", 
"commutative"]:
+            if opt in options:
+                del options[opt]
+        r = cls.canonize(*args)
+        if r is None:
+            return Basic.__new__(cls, *args, **options)
+        else:
+            return r
 
     @classmethod
     def canonize(cls, *args):
-        if not args[0].is_Symbol:
-            raise TypeError, "First argument must be symbol"
-        for piece in args[1:]:
-            if not isinstance(piece,tuple) or len(piece) != 3:
-                raise TypeError, "Must use 3-tuples for intervals"
+        # Check types first
+        for ec in args:
+            if (not isinstance(ec,tuple)) or len(ec)!=2:
+                raise TypeError, "args may only include (expr, cond) pairs"
+        for expr, cond in args:
+            cond_type = type(ec[1])
+            if cond_type != bool and not issubclass(cond_type,Relational):
+                raise TypeError, \
+                    "Cond %s is of type %s, but must be a bool or Relational" \
+                    % (cond, cond_type)
+
+        # Check for situations where we can evaluate the Piecewise object.
+        # 1) Hit an unevaluatable cond (e.g. x<1) -> keep object
+        # 2) Hit a true condition -> return that expr
+        # 3) Remove false conditions, if no conditions left -> raise ValueError
+        all_conds_evaled = True
+        non_false_ecpairs = []
+        for expr, cond in args:
+            cond_eval = cls.__eval_cond(cond)
+            if cond_eval is None:
+                all_conds_evaled = False
+                non_false_ecpairs.append( (expr, cond) )
+            elif cond_eval:
+                if all_conds_evaled:
+                    return expr
+                non_false_ecpairs.append( (expr, cond) )
+        if len(non_false_ecpairs) != len(args):
+            return Piecewise(*non_false_ecpairs)
+
+        # Count number of arguments.
+        nargs = 0
+        for expr, cond in args:
+            if hasattr(expr, 'nargs'):
+                nargs = max(nargs, expr.nargs)
+            elif hasattr(expr, 'args'):
+                nargs = max(nargs, len(expr.args))
+        if nargs:
+            cls.narg = nargs
         return None
 
     def _eval_derivative(self, s):
-        new_pieces = []
-        for start, end, f in self.args[1:]:
-            t = (start, end, diff(f, s))
-            new_pieces.append( t )
-        return Piecewise(self.args[0], *new_pieces)
+        return Piecewise(*[(diff(expr,s),cond) for expr, cond in self.args])
 
     def _eval_subs(self, old, new):
         if self == old:
             return new
-        new_pieces = []
-        for start, end, f in self.args[1:]:
-            new_pieces.append( (start, end, f._eval_subs(old,new)) )
-        return Piecewise(self.args[0], *new_pieces)
+        new_ecpairs = []
+        for expr, cond in self.args:
+            if hasattr(expr,"subs") and not isinstance(expr,FunctionClass):
+                new_expr = expr.subs(old, new)
+            else:
+                new_expr = expr
+            if hasattr(cond,"subs"):
+                new_cond = cond.subs(old, new)
+            else:
+                new_cond = cond
+            new_ecpairs.append( (new_expr, new_cond) )
+        return Piecewise(*new_ecpairs)
+
+    @classmethod
+    def __eval_cond(cls, cond):
+        """Returns if the condition is True or False.
+           If it is undeterminable, returns None."""
+        if type(cond) == bool:
+            return cond
+        arg0 = cond.args[0]
+        arg1 = cond.args[1]
+        if isinstance(arg0, FunctionClass) or isinstance(arg1, FunctionClass):
+            return None
+        if hasattr(arg0,'evalf'):
+            arg0 = arg0.evalf()
+        if not issubclass(type(arg0),Number) and \
+                type(arg0) != int and type(arg0) != float:
+            return None
+        if hasattr(arg1,'evalf'):
+            arg1 = arg1.evalf()
+        if not issubclass(type(arg1),Number) and \
+                type(arg1) != int and type(arg1) != float:
+            return None
+        return bool(cond)
diff -r 237526ad1977 -r d040911a2af5 
sympy/functions/elementary/tests/test_piecewise.py
--- a/sympy/functions/elementary/tests/test_piecewise.py        Thu Oct 16 
18:53:28 2008 +0200
+++ b/sympy/functions/elementary/tests/test_piecewise.py        Fri Oct 17 
08:32:44 2008 +0200
@@ -1,13 +1,51 @@
-from sympy import oo, diff, log, Symbol, Piecewise
+from sympy import diff, Integral, integrate, log, oo, Piecewise, symbols
 
-x = Symbol('x')
+x,y = symbols('xy')
 
 def test_piecewise():
-    assert Piecewise(x, (0,1,x)) == Piecewise(x, (0,1,x))
 
-    p = Piecewise(x, (-oo, -1, -1), (-1, 0, x**2), (0, oo, log(x)))
-    dp = Piecewise(x, (-oo, -1, 0), (-1, 0, 2*x), (0, oo, 1/x))
+    # Test canonization
+    assert Piecewise((x, x < 1), (0, True)) == Piecewise((x, x < 1), (0, True))
+    assert Piecewise((x, x < 1), (0, False), (-1, 1>2)) == Piecewise((x, x < 
1))
+    assert Piecewise((x, True)) == x
+
+    exception_called = False
+    try:
+        Piecewise(x)
+    except TypeError:
+        exception_called = True
+    assert exception_called
+
+    exception_called = False
+    try:
+        Piecewise((x,x**2))
+    except TypeError:
+        exception_called = True
+    assert exception_called
+
+    # Test subs
+    p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), x >=0))
+    p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), x**2 >=0))
+    assert p.subs(x,x**2) == p_x2
+    assert p.subs(x,-5) == -1
+    assert p.subs(x,-1) == 1
+    assert p.subs(x,1) == log(1)
+
+    # Test differentiation
+    f = x
+    fp = x*p
+    dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, x >= 0))
+    fp_dx = x*dp + p
     assert diff(p,x) == dp
+    # FIXME: Seems that the function derivatives are flipping the args.
+    # assert diff(f*p,x) == fp_dx
+    # Test args for now.
+    assert fp_dx.args[0] == diff(f*p,x).args[1]
+    assert fp_dx.args[1] == diff(f*p,x).args[0]
 
-    p_x2 = Piecewise(x, (-oo, -1, -1), (-1, 0, x**4), (0, oo, log(x**2)))
-    assert p.subs(x,x**2) == p_x2
+    # Test simple arithmetic
+    assert x*p == fp
+    assert x*p + p == p + x*p
+    assert p + f == f + p
+    assert p + dp == dp + p
+    assert p - dp == -(dp - p)

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-commits" group.
To post to this group, send email to sympy-commits@googlegroups.com
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at 
http://groups.google.com/group/sympy-commits?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to