Issue 626 is fixed by these two patches.  Also, solve can now solve for
Symbol, Function, and Derivative instances.

I added tests for the new behavior, and the full test suite passes on my machine.

Thanks,
~Luke

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

From c0a8c3f80c494b703a7b2a1c80ebaab691f0aaeb Mon Sep 17 00:00:00 2001
From: Luke Peterson <hazelnu...@gmail.com>
Date: Thu, 21 May 2009 18:42:24 -0700
Subject: [PATCH] Implemented code to allow for solving for Function and Derivative instances.

---
 sympy/solvers/solvers.py            |   56 +++++++++++++++++++++++++++++++---
 sympy/solvers/tests/test_solvers.py |   20 ++++++++++++-
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index ae42d4b..50ad7a4 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -19,7 +19,7 @@ from sympy.core.basic import Basic, S, C, Mul, Add
 from sympy.core.power import Pow
 from sympy.core.symbol import Symbol, Wild
 from sympy.core.relational import Equality
-from sympy.core.function import Derivative, diff
+from sympy.core.function import Derivative, diff, Function
 from sympy.core.numbers import ilcm
 
 from sympy.functions import sqrt, log, exp, LambertW
@@ -148,12 +148,40 @@ def solve(f, *symbols, **flags):
     symbols = map(sympify, symbols)
     result = list()
 
-    if any(not s.is_Symbol for s in symbols):
-        raise TypeError('not a Symbol')
+    symbols_new = []
+    SYMBOL_SWAPPED = False
+
+    if isinstance(symbols, (list, tuple)):
+        symbols_passed = symbols[:]
+    elif isinstance(symbols, set):
+        symbols_passed = symbols.copy()
+
+    i = 0
+    for s in symbols:
+        if s.is_Symbol:
+            s_new = s
+        elif s.is_Function:
+            SYMBOL_SWAPPED = True
+            s_new = Symbol('F%d' % i, dummy=True)
+        elif s.is_Derivative:
+            SYMBOL_SWAPPED = True
+            s_new = Symbol('D%d' % i, dummy=True)
+        else:
+            raise TypeError('not a Symbol or a Function')
+        symbols_new.append(s_new)
+        i += 1
+
+    if SYMBOL_SWAPPED:
+        swap_back_dict = dict(zip(symbols_new, symbols))
 
     if not isinstance(f, (tuple, list, set)):
         f = sympify(f)
 
+        if SYMBOL_SWAPPED:
+            swap_dict = zip(symbols, symbols_new)
+            f = f.subs(swap_dict)
+            symbols = symbols_new
+
         if isinstance(f, Equality):
             f = f.lhs - f.rhs
 
@@ -246,6 +274,11 @@ def solve(f, *symbols, **flags):
         if not f:
             return {}
         else:
+            if SYMBOL_SWAPPED:
+                swap_dict = zip(symbols, symbols_new)
+                f = [fi.subs(swap_dict) for fi in f]
+                symbols = symbols_new
+
             polys = []
 
             for g in f:
@@ -273,9 +306,21 @@ def solve(f, *symbols, **flags):
                         except ValueError:
                             matrix[i, m] = -coeff
 
-                return solve_linear_system(matrix, *symbols, **flags)
+                soln = solve_linear_system(matrix, *symbols, **flags)
             else:
-                return solve_poly_system(polys)
+                soln = solve_poly_system(polys)
+
+            if SYMBOL_SWAPPED:
+                if isinstance(soln, dict):
+                    res = {}
+                    for k in soln.keys():
+                        res.update({swap_back_dict[k]: soln[k]})
+                    return res
+                else:
+                    return soln
+            else:
+                return soln
+
 
 def solve_linear_system(system, *symbols, **flags):
     """Solve system of N linear equations with M variables, which means
@@ -883,3 +928,4 @@ def nsolve(*args, **kwargs):
     # solve the system numerically
     x = findroot(f, x0, J=J, **kwargs)
     return x
+
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index c5ea403..de6359e 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -113,7 +113,6 @@ def test_solve_polynomial_cv_1a():
 def test_solve_polynomial_cv_1b():
     x, a = symbols('x a')
 
-
     assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == set([S(0), 1/a**2])
     assert set(solve(x * (x**(S(1)/3) - 3), x)) == set([S(0), S(27)])
 
@@ -261,3 +260,22 @@ def test_tsolve_1():
 def test_tsolve_2():
     x, y, a, b = symbols('xyab')
     assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)]
+
+
+def test_solveForFunctionsDerivatives():
+    t = Symbol('t')
+    x = Function('x')(t)
+    y = Function('y')(t)
+    a11,a12,a21,a22,b1,b2 = symbols('a11','a12','a21','a22','b1','b2')
+    soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y)
+    assert soln == { y : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+          x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+    assert solve(x-1, x) == [1]
+    assert solve(3*x-2, x) == [Rational(2,3)]
+    soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) +
+        a22*y.diff(t) - b2], x.diff(t), y.diff(t))
+    assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21),
+          x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
+    assert solve(x.diff(t)-1, x.diff(t)) == [1]
+    assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)]
+
-- 
1.6.0.4

From aeef51e6907af4062909c101e2c1b7164ae1813e Mon Sep 17 00:00:00 2001
From: Luke Peterson <hazelnu...@gmail.com>
Date: Fri, 22 May 2009 10:33:07 -0700
Subject: [PATCH] More functionality to allow solve to work with sets and Derivative objects.

---
 sympy/core/basic.py                 |    1 +
 sympy/solvers/solvers.py            |    2 +-
 sympy/solvers/tests/test_solvers.py |    8 ++++++++
 3 files changed, 10 insertions(+), 1 deletions(-)

diff --git a/sympy/core/basic.py b/sympy/core/basic.py
index b48e633..f5a7556 100644
--- a/sympy/core/basic.py
+++ b/sympy/core/basic.py
@@ -70,6 +70,7 @@ ordering_of_classes = [
     'Lambda',
     # operators
     'FDerivative','FApply',
+
     # composition of functions
     'FPow', 'Composition',
     # Landau O symbol
diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 50ad7a4..e9d1445 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -154,7 +154,7 @@ def solve(f, *symbols, **flags):
     if isinstance(symbols, (list, tuple)):
         symbols_passed = symbols[:]
     elif isinstance(symbols, set):
-        symbols_passed = symbols.copy()
+        symbols_passed = list(symbols)
 
     i = 0
     for s in symbols:
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index de6359e..6a191f4 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -278,4 +278,12 @@ def test_solveForFunctionsDerivatives():
           x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) }
     assert solve(x.diff(t)-1, x.diff(t)) == [1]
     assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)]
+    eqns = set((3*x - 1, 2*y-4))
+    assert solve(eqns, set((x,y))) == { x : Rational(1, 3), y: 2 }
+
+    x = Symbol('x')
+    f = Function('f')
+    F = x**2 + f(x)**2 - 4*x - 1
+    assert solve(F.diff(x), diff(f(x), x)) == [(2 - x)/f(x)]
+
 
-- 
1.6.0.4

Reply via email to