Attached is a patch that fixes issue 626. It allows solve() to
handle Function and Derivative instances, adds tests for both
cases, and includes extensive comments.

All tests pass on my machine.

Please review and let me know if I can make it better.

Thanks,
~Luke

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

From e06e32d80eaab3ed4038faa17403e95eb72eb0f0 Mon Sep 17 00:00:00 2001
From: Luke Peterson <hazelnu...@gmail.com>
Date: Mon, 25 May 2009 15:39:02 -0700
Subject: [PATCH] Fix for Issue 626, solve() can now handle Function and Derivative instances.

---
 sympy/solvers/solvers.py            |   68 ++++++++++++++++++++++++++++++++---
 sympy/solvers/tests/test_solvers.py |   38 +++++++++++++++++++
 2 files changed, 101 insertions(+), 5 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index ae42d4b..d608f46 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -19,7 +19,7 @@ from sympy.core.basic import Basic, S, C, Mul, Add
 from sympy.core.power import Pow
 from sympy.core.symbol import Symbol, Wild
 from sympy.core.relational import Equality
-from sympy.core.function import Derivative, diff
+from sympy.core.function import Derivative, diff, Function
 from sympy.core.numbers import ilcm
 
 from sympy.functions import sqrt, log, exp, LambertW
@@ -148,12 +148,50 @@ def solve(f, *symbols, **flags):
     symbols = map(sympify, symbols)
     result = list()
 
-    if any(not s.is_Symbol for s in symbols):
-        raise TypeError('not a Symbol')
+    # Begin code handling for Function and Derivative instances
+    # Basic idea:  store all the passed symbols in symbols_passed, check to see
+    # if any of them are Function or Derivative types, if so, use a dummy
+    # symbol in their place, and set symbol_swapped = True so that other parts
+    # of the code can be aware of the swap.  Once all swapping is done,
+    # continue with regular solving as usual, and swap back at the end of
+    # the routine, so that whatever was passed in symbols is what is returned.
+    symbols_new = []
+    symbol_swapped = False
+
+    if isinstance(symbols, (list, tuple)):
+        symbols_passed = symbols[:]
+    elif isinstance(symbols, set):
+        symbols_passed = list(symbols)
+
+    i = 0
+    for s in symbols:
+        if s.is_Symbol:
+            s_new = s
+        elif s.is_Function:
+            symbol_swapped = True
+            s_new = Symbol('F%d' % i, dummy=True)
+        elif s.is_Derivative:
+            symbol_swapped = True
+            s_new = Symbol('D%d' % i, dummy=True)
+        else:
+            raise TypeError('not a Symbol or a Function')
+        symbols_new.append(s_new)
+        i += 1
+
+        if symbol_swapped:
+            swap_back_dict = dict(zip(symbols_new, symbols))
+    # End code for handling of Function and Derivative instances
 
     if not isinstance(f, (tuple, list, set)):
         f = sympify(f)
 
+        # Create a swap dictionary for storing the passed symbols to be solved
+        # for, so that they may be swapped back.
+        if symbol_swapped:
+            swap_dict = zip(symbols, symbols_new)
+            f = f.subs(swap_dict)
+            symbols = symbols_new
+
         if isinstance(f, Equality):
             f = f.lhs - f.rhs
 
@@ -246,6 +284,13 @@ def solve(f, *symbols, **flags):
         if not f:
             return {}
         else:
+            # Create a swap dictionary for storing the passed symbols to be
+            # solved for, so that they may be swapped back.
+            if symbol_swapped:
+                swap_dict = zip(symbols, symbols_new)
+                f = [fi.subs(swap_dict) for fi in f]
+                symbols = symbols_new
+
             polys = []
 
             for g in f:
@@ -273,9 +318,22 @@ def solve(f, *symbols, **flags):
                         except ValueError:
                             matrix[i, m] = -coeff
 
-                return solve_linear_system(matrix, *symbols, **flags)
+                soln = solve_linear_system(matrix, *symbols, **flags)
+            else:
+                soln = solve_poly_system(polys)
+
+            # Use swap_back_dict so a dict solution is keyed by the objects
+            # originally passed (Functions/Derivatives), not the dummies
+            if symbol_swapped:
+                if isinstance(soln, dict):
+                    res = {}
+                    for k in soln.keys():
+                        res.update({swap_back_dict[k]: soln[k]})
+                    return res
+                else:
+                    return soln
             else:
-                return solve_poly_system(polys)
+                return soln
 
 def solve_linear_system(system, *symbols, **flags):
     """Solve system of N linear equations with M variables, which means
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index c5ea403..5fe2fea 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -261,3 +261,41 @@ def test_tsolve_1():
 def test_tsolve_2():
     x, y, a, b = symbols('xyab')
     assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)]
+
+def test_solveForFunctionsDerivatives():  # solve() with Function/Derivative unknowns
+    t = Symbol('t')
+    x = Function('x')(t)  # undefined function x applied to t
+    y = Function('y')(t)  # undefined function y applied to t
+    a11,a12,a21,a22,b1,b2 = symbols('a11','a12','a21','a22','b1','b2')
+
+    soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y)  # 2x2 linear system
+    assert soln == { y : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+        x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+
+    assert solve(x-1, x) == [1]  # single equation, single Function unknown
+    assert solve(3*x-2, x) == [Rational(2,3)]
+
+    soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) +
+            a22*y.diff(t) - b2], x.diff(t), y.diff(t))
+    assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+            x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+
+    assert solve(x.diff(t)-1, x.diff(t)) == [1]  # single Derivative unknown
+    assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)]
+
+    eqns = set((3*x - 1, 2*y-4))  # equations passed as a set
+    assert solve(eqns, set((x,y))) == { x : Rational(1, 3), y: 2 }  # NOTE(review): unknowns passed as one set argument; confirm solve() unpacks it
+    x = Symbol('x')  # x is now a plain Symbol for the implicit-diff case
+    f = Function('f')
+    F = x**2 + f(x)**2 - 4*x - 1
+    assert solve(F.diff(x), diff(f(x), x)) == [(2 - x)/f(x)]  # implicit differentiation: solve dF/dx = 0 for f'(x)
+
+    # Mixed case: a Symbol and a Derivative as the unknowns
+    x = Symbol('x')  # already a Symbol at this point; rebinding is redundant
+    y = Function('y')(t)  # same expression y was bound to at the top; redundant
+
+    soln = solve([a11*x + a12*y.diff(t) - b1, a21*x +
+            a22*y.diff(t) - b2], x, y.diff(t))
+    assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+            x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+
-- 
1.6.0.4

Reply via email to