Attached is a patch that lets solve() solve for Symbol, Function, or
Derivative instances, so equations can now be solved for unknowns such as x(t) or x(t).diff(t).

It passes all tests on my machine. Let me know if there are any problems.

I will be working on a similar patch that does the same thing for diff().


Thanks,
~Luke

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

From c0a8c3f80c494b703a7b2a1c80ebaab691f0aaeb Mon Sep 17 00:00:00 2001
From: Luke Peterson <hazelnu...@gmail.com>
Date: Thu, 21 May 2009 18:42:24 -0700
Subject: [PATCH] Implemented code to allow for solving for Function and Derivative instances.

---
 sympy/solvers/solvers.py            |   56 +++++++++++++++++++++++++++++++---
 sympy/solvers/tests/test_solvers.py |   20 ++++++++++++-
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index ae42d4b..50ad7a4 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -19,7 +19,7 @@ from sympy.core.basic import Basic, S, C, Mul, Add
 from sympy.core.power import Pow
 from sympy.core.symbol import Symbol, Wild
 from sympy.core.relational import Equality
-from sympy.core.function import Derivative, diff
+from sympy.core.function import Derivative, diff, Function
 from sympy.core.numbers import ilcm
 
 from sympy.functions import sqrt, log, exp, LambertW
@@ -148,12 +148,40 @@ def solve(f, *symbols, **flags):
     symbols = map(sympify, symbols)
     result = list()
 
-    if any(not s.is_Symbol for s in symbols):
-        raise TypeError('not a Symbol')
+    symbols_new = []
+    SYMBOL_SWAPPED = False
+
+    if isinstance(symbols, (list, tuple)):
+        symbols_passed = symbols[:]
+    elif isinstance(symbols, set):
+        symbols_passed = symbols.copy()
+
+    i = 0
+    for s in symbols:
+        if s.is_Symbol:
+            s_new = s
+        elif s.is_Function:
+            SYMBOL_SWAPPED = True
+            s_new = Symbol('F%d' % i, dummy=True)
+        elif s.is_Derivative:
+            SYMBOL_SWAPPED = True
+            s_new = Symbol('D%d' % i, dummy=True)
+        else:
+            raise TypeError('not a Symbol or a Function')
+        symbols_new.append(s_new)
+        i += 1
+
+    if SYMBOL_SWAPPED:
+        swap_back_dict = dict(zip(symbols_new, symbols))
 
     if not isinstance(f, (tuple, list, set)):
         f = sympify(f)
 
+        if SYMBOL_SWAPPED:
+            swap_dict = zip(symbols, symbols_new)
+            f = f.subs(swap_dict)
+            symbols = symbols_new
+
         if isinstance(f, Equality):
             f = f.lhs - f.rhs
 
@@ -246,6 +274,11 @@ def solve(f, *symbols, **flags):
         if not f:
             return {}
         else:
+            if SYMBOL_SWAPPED:
+                swap_dict = zip(symbols, symbols_new)
+                f = [fi.subs(swap_dict) for fi in f]
+                symbols = symbols_new
+
             polys = []
 
             for g in f:
@@ -273,9 +306,21 @@ def solve(f, *symbols, **flags):
                         except ValueError:
                             matrix[i, m] = -coeff
 
-                return solve_linear_system(matrix, *symbols, **flags)
+                soln = solve_linear_system(matrix, *symbols, **flags)
             else:
-                return solve_poly_system(polys)
+                soln = solve_poly_system(polys)
+
+            if SYMBOL_SWAPPED:
+                if isinstance(soln, dict):
+                    res = {}
+                    for k in soln.keys():
+                        res.update({swap_back_dict[k]: soln[k]})
+                    return res
+                else:
+                    return soln
+            else:
+                return soln
+
 
 def solve_linear_system(system, *symbols, **flags):
     """Solve system of N linear equations with M variables, which means
@@ -883,3 +928,4 @@ def nsolve(*args, **kwargs):
     # solve the system numerically
     x = findroot(f, x0, J=J, **kwargs)
     return x
+
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index c5ea403..de6359e 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -113,7 +113,6 @@ def test_solve_polynomial_cv_1a():
 def test_solve_polynomial_cv_1b():
     x, a = symbols('x a')
 
-
     assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == set([S(0), 1/a**2])
     assert set(solve(x * (x**(S(1)/3) - 3), x)) == set([S(0), S(27)])
 
@@ -261,3 +260,22 @@ def test_tsolve_1():
 def test_tsolve_2():
     x, y, a, b = symbols('xyab')
     assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)]
+
+
+def test_solveForFunctionsDerivatives():
+    t = Symbol('t')
+    x = Function('x')(t)
+    y = Function('y')(t)
+    a11,a12,a21,a22,b1,b2 = symbols('a11','a12','a21','a22','b1','b2')
+    soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y)
+    assert soln == { y : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+          x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+    assert solve(x-1, x) == [1]
+    assert solve(3*x-2, x) == [Rational(2,3)]
+    soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) +
+        a22*y.diff(t) - b2], x.diff(t), y.diff(t))
+    assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+          x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+    assert solve(x.diff(t)-1, x.diff(t)) == [1]
+    assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)]
+
-- 
1.6.0.4

Reply via email to