Issue 626 is fixed by these two patches. Also, solve can now solve for Symbol, Function, and Derivative types.
I implemented tests and everything passes on my machine. Thanks, ~Luke --~--~---------~--~----~------------~-------~--~----~ You received this message because you are subscribed to the Google Groups "sympy-patches" group. To post to this group, send email to sympy-patches@googlegroups.com To unsubscribe from this group, send email to sympy-patches+unsubscr...@googlegroups.com For more options, visit this group at http://groups.google.com/group/sympy-patches?hl=en -~----------~----~----~----~------~----~------~--~---
From c0a8c3f80c494b703a7b2a1c80ebaab691f0aaeb Mon Sep 17 00:00:00 2001 From: Luke Peterson <hazelnu...@gmail.com> Date: Thu, 21 May 2009 18:42:24 -0700 Subject: [PATCH] Implemented code to allow for solving for Function and Derivative instances. --- sympy/solvers/solvers.py | 56 +++++++++++++++++++++++++++++++--- sympy/solvers/tests/test_solvers.py | 20 ++++++++++++- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index ae42d4b..50ad7a4 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -19,7 +19,7 @@ from sympy.core.basic import Basic, S, C, Mul, Add from sympy.core.power import Pow from sympy.core.symbol import Symbol, Wild from sympy.core.relational import Equality -from sympy.core.function import Derivative, diff +from sympy.core.function import Derivative, diff, Function from sympy.core.numbers import ilcm from sympy.functions import sqrt, log, exp, LambertW @@ -148,12 +148,40 @@ def solve(f, *symbols, **flags): symbols = map(sympify, symbols) result = list() - if any(not s.is_Symbol for s in symbols): - raise TypeError('not a Symbol') + symbols_new = [] + SYMBOL_SWAPPED = False + + if isinstance(symbols, (list, tuple)): + symbols_passed = symbols[:] + elif isinstance(symbols, set): + symbols_passed = symbols.copy() + + i = 0 + for s in symbols: + if s.is_Symbol: + s_new = s + elif s.is_Function: + SYMBOL_SWAPPED = True + s_new = Symbol('F%d' % i, dummy=True) + elif s.is_Derivative: + SYMBOL_SWAPPED = True + s_new = Symbol('D%d' % i, dummy=True) + else: + raise TypeError('not a Symbol or a Function') + symbols_new.append(s_new) + i += 1 + + if SYMBOL_SWAPPED: + swap_back_dict = dict(zip(symbols_new, symbols)) if not isinstance(f, (tuple, list, set)): f = sympify(f) + if SYMBOL_SWAPPED: + swap_dict = zip(symbols, symbols_new) + f = f.subs(swap_dict) + symbols = symbols_new + if isinstance(f, Equality): f = f.lhs - f.rhs @@ -246,6 +274,11 @@ def solve(f, *symbols, **flags): 
if not f: return {} else: + if SYMBOL_SWAPPED: + swap_dict = zip(symbols, symbols_new) + f = [fi.subs(swap_dict) for fi in f] + symbols = symbols_new + polys = [] for g in f: @@ -273,9 +306,21 @@ def solve(f, *symbols, **flags): except ValueError: matrix[i, m] = -coeff - return solve_linear_system(matrix, *symbols, **flags) + soln = solve_linear_system(matrix, *symbols, **flags) else: - return solve_poly_system(polys) + soln = solve_poly_system(polys) + + if SYMBOL_SWAPPED: + if isinstance(soln, dict): + res = {} + for k in soln.keys(): + res.update({swap_back_dict[k]: soln[k]}) + return res + else: + return soln + else: + return soln + def solve_linear_system(system, *symbols, **flags): """Solve system of N linear equations with M variables, which means @@ -883,3 +928,4 @@ def nsolve(*args, **kwargs): # solve the system numerically x = findroot(f, x0, J=J, **kwargs) return x + diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index c5ea403..de6359e 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -113,7 +113,6 @@ def test_solve_polynomial_cv_1a(): def test_solve_polynomial_cv_1b(): x, a = symbols('x a') - assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == set([S(0), 1/a**2]) assert set(solve(x * (x**(S(1)/3) - 3), x)) == set([S(0), S(27)]) @@ -261,3 +260,22 @@ def test_tsolve_1(): def test_tsolve_2(): x, y, a, b = symbols('xyab') assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)] + + +def test_solveForFunctionsDerivatives(): + t = Symbol('t') + x = Function('x')(t) + y = Function('y')(t) + a11,a12,a21,a22,b1,b2 = symbols('a11','a12','a21','a22','b1','b2') + soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) + assert soln == { y : (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + assert solve(x-1, x) == [1] + assert solve(3*x-2, x) == [Rational(2,3)] + soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) + + a22*y.diff(t) - 
b2], x.diff(t), y.diff(t)) + assert soln == { y.diff(t) : (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + assert solve(x.diff(t)-1, x.diff(t)) == [1] + assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)] + -- 1.6.0.4
From aeef51e6907af4062909c101e2c1b7164ae1813e Mon Sep 17 00:00:00 2001 From: Luke Peterson <hazelnu...@gmail.com> Date: Fri, 22 May 2009 10:33:07 -0700 Subject: [PATCH] More functionality to allow solve to work with sets and Derivative objects. --- sympy/core/basic.py | 1 + sympy/solvers/solvers.py | 2 +- sympy/solvers/tests/test_solvers.py | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletions(-) diff --git a/sympy/core/basic.py b/sympy/core/basic.py index b48e633..f5a7556 100644 --- a/sympy/core/basic.py +++ b/sympy/core/basic.py @@ -70,6 +70,7 @@ ordering_of_classes = [ 'Lambda', # operators 'FDerivative','FApply', + # composition of functions 'FPow', 'Composition', # Landau O symbol diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py index 50ad7a4..e9d1445 100644 --- a/sympy/solvers/solvers.py +++ b/sympy/solvers/solvers.py @@ -154,7 +154,7 @@ def solve(f, *symbols, **flags): if isinstance(symbols, (list, tuple)): symbols_passed = symbols[:] elif isinstance(symbols, set): - symbols_passed = symbols.copy() + symbols_passed = list(symbols) i = 0 for s in symbols: diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py index de6359e..6a191f4 100644 --- a/sympy/solvers/tests/test_solvers.py +++ b/sympy/solvers/tests/test_solvers.py @@ -278,4 +278,12 @@ def test_solveForFunctionsDerivatives(): x.diff(t) : (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } assert solve(x.diff(t)-1, x.diff(t)) == [1] assert solve(3*x.diff(t)-2, x.diff(t)) == [Rational(2,3)] + eqns = set((3*x - 1, 2*y-4)) + assert solve(eqns, set((x,y))) == { x : Rational(1, 3), y: 2 } + + x = Symbol('x') + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + assert solve(F.diff(x), diff(f(x), x)) == [(2 - x)/f(x)] + -- 1.6.0.4