Attached is a patch that lets you use solve() on equations containing Symbol, Function, or Derivative instances, and solve for any of them (not only for Symbols).
It passed all tests on my machine. Let me know if there are any problems. I will be working on a similar patch to do the same thing for diff(). Thanks, ~Luke --~--~---------~--~----~------------~-------~--~----~ You received this message because you are subscribed to the Google Groups "sympy-patches" group. To post to this group, send email to sympy-patches@googlegroups.com To unsubscribe from this group, send email to sympy-patches+unsubscr...@googlegroups.com For more options, visit this group at http://groups.google.com/group/sympy-patches?hl=en -~----------~----~----~----~------~----~------~--~---
From c0a8c3f80c494b703a7b2a1c80ebaab691f0aaeb Mon Sep 17 00:00:00 2001 From: Luke Peterson <hazelnu...@gmail.com> Date: Thu, 21 May 2009 18:42:24 -0700 Subject: [PATCH] Implemented code to allow for solving for Function and Derivative instances. --- sympy/solvers/solvers.py | 56 +++++++++++++++++++++++++++++++--- sympy/solvers/tests/test_solvers.py | 20 ++++++++++++- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index ae42d4b..50ad7a4 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -19,7 +19,7 @@ from sympy.core.basic import Basic, S, C, Mul, Add from sympy.core.power import Pow from sympy.core.symbol import Symbol, Wild from sympy.core.relational import Equality -from sympy.core.function import Derivative, diff +from sympy.core.function import Derivative, diff, Function from sympy.core.numbers import ilcm from sympy.functions import sqrt, log, exp, LambertW @@ -148,12 +148,40 @@ def solve(f, *symbols, **flags): symbols = map(sympify, symbols) result = list() - if any(not s.is_Symbol for s in symbols): - raise TypeError('not a Symbol') + symbols_new = [] + SYMBOL_SWAPPED = False + + if isinstance(symbols, (list, tuple)): + symbols_passed = symbols[:] + elif isinstance(symbols, set): + symbols_passed = symbols.copy() + + i = 0 + for s in symbols: + if s.is_Symbol: + s_new = s + elif s.is_Function: + SYMBOL_SWAPPED = True + s_new = Symbol('F%d' % i, dummy=True) + elif s.is_Derivative: + SYMBOL_SWAPPED = True + s_new = Symbol('D%d' % i, dummy=True) + else: + raise TypeError('not a Symbol or a Function') + symbols_new.append(s_new) + i += 1 + + if SYMBOL_SWAPPED: + swap_back_dict = dict(zip(symbols_new, symbols)) if not isinstance(f, (tuple, list, set)): f = sympify(f) + if SYMBOL_SWAPPED: + swap_dict = zip(symbols, symbols_new) + f = f.subs(swap_dict) + symbols = symbols_new + if isinstance(f, Equality): f = f.lhs - f.rhs @@ -246,6 +274,11 @@ def solve(f, *symbols, **flags): 
if not f: return {} else: + if SYMBOL_SWAPPED: + swap_dict = zip(symbols, symbols_new) + f = [fi.subs(swap_dict) for fi in f] + symbols = symbols_new + polys = [] for g in f: @@ -273,9 +306,21 @@ def solve(f, *symbols, **flags): except ValueError: matrix[i, m] = -coeff - return solve_linear_system(matrix, *symbols, **flags) + soln = solve_linear_system(matrix, *symbols, **flags) else: - return solve_poly_system(polys) + soln = solve_poly_system(polys) + + if SYMBOL_SWAPPED: + if isinstance(soln, dict): + res = {} + for k in soln.keys(): + res.update({swap_back_dict[k]: soln[k]}) + return res + else: + return soln + else: + return soln + def solve_linear_system(system, *symbols, **flags): """Solve system of N linear equations with M variables, which means @@ -883,3 +928,4 @@ def nsolve(*args, **kwargs): # solve the system numerically x = findroot(f, x0, J=J, **kwargs) return x + diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index c5ea403..de6359e 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -113,7 +113,6 @@ def test_solve_polynomial_cv_1a(): def test_solve_polynomial_cv_1b(): x, a = symbols('x a') - assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == set([S(0), 1/a**2]) assert set(solve(x * (x**(S(1)/3) - 3), x)) == set([S(0), S(27)]) @@ -261,3 +260,22 @@ def test_tsolve_1(): def test_tsolve_2(): x, y, a, b = symbols('xyab') assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)] + + +def test_solveForFunctionsDerivatives(): + t = Symbol('t') + x = Function('x')(t) + y = Function('y')(t) + a11,a12,a21,a22,b1,b2 = symbols('a11','a12','a21','a22','b1','b2') + soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) + assert soln == { y : (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + assert solve(x-1, x) == [1] + assert solve(3*x-2, x) == [Rational(2,3)] + soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) + + a22*y.diff(t) - 
b2], x.diff(t), y.diff(t)) + assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + assert solve(x.diff(t)-1, x.diff(t)) == [1] + assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)] + -- 1.6.0.4