Second try.

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

From 423f56cf3a63b8c9e2918b07e60dfca2c2913fc9 Mon Sep 17 00:00:00 2001
From: Vinzent Steinberg <vinzent.steinb...@gmail.com>
Date: Fri, 10 Apr 2009 10:47:16 +0200
Subject: [PATCH] change nsolve(x, f, x0) to nsolve(f, x, x0), fix a doctest, support for Eq

For one-dimensional functions the variable can be omitted, i.e. nsolve(f, x0) is valid.
---
 sympy/solvers/solvers.py            |   53 ++++++++++++++++++++++++++--------
 sympy/solvers/tests/test_numeric.py |    9 +++--
 2 files changed, 45 insertions(+), 17 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 060884b..adb07d9 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -801,15 +801,19 @@ def msolve(*args, **kwargs):
     msolve() has been renamed to nsolve(), please use nsolve() directly."""
     warn('msolve() is has been renamed, please use nsolve() instead',
          DeprecationWarning)
+    args = (args[1], args[0]) + args[2:]  # msolve took (args, f, x0)
     return nsolve(*args, **kwargs)
 
-# TODO: option for calculating J numerically, support for Eq()
-def nsolve(args, f, x0, modules=['mpmath'], **kwargs):
+# TODO: option for calculating J numerically
+def nsolve(*args, **kwargs):
     """
     Solve a nonlinear equation system numerically.
 
+    nsolve(f, [args,] x0, modules=['mpmath'], **kwargs)
+
     f is a vector function of symbolic expressions representing the system.
-    args are the variables.
+    args are the variables. If there is only one variable, this argument can be
+    omitted.
     x0 is a starting vector close to a solution.
 
     Use the modules keyword to specify which modules should be used to evaluate
@@ -820,48 +824,71 @@ def nsolve(args, f, x0, modules=['mpmath'], **kwargs):
     Overdetermined systems are supported.
 
     >>> from sympy import Symbol, nsolve
+    >>> import sympy
+    >>> sympy.mpmath.mp.dps = 15
     >>> x1 = Symbol('x1')
     >>> x2 = Symbol('x2')
     >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
     >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
-    >>> print nsolve((x1, x2), (f1, f2), (-1, 1))
+    >>> print nsolve((f1, f2), (x1, x2), (-1, 1))
     [-1.19287309935246]
     [ 1.27844411169911]
 
     For onedimensional functions the syntax is simplified:
 
     >>> from sympy import sin
-    >>> nsolve(x, sin(x), 2)
+    >>> nsolve(sin(x), x, 2)
+    mpf('3.1415926535897932')
+    >>> nsolve(sin(x), 2)
     mpf('3.1415926535897932')
 
     mpmath.findroot is used, you can find there more extensive documentation,
     especially concerning keyword parameters and available solvers.
     """
     # interpret arguments
+    if len(args) == 3:
+        f = args[0]
+        fargs = args[1]
+        x0 = args[2]
+    elif len(args) == 2:
+        f = args[0]
+        fargs = None
+        x0 = args[1]
+    elif len(args) < 2:
+        raise TypeError('nsolve expected at least 2 arguments, got %i'
+                        % len(args))
+    else:
+        raise TypeError('nsolve expected at most 3 arguments, got %i'
+                        % len(args))
+    modules = kwargs.pop('modules', ['mpmath'])  # not a findroot option
     if isinstance(f,  (list,  tuple)):
         f = Matrix(f).T
     if not isinstance(f, Matrix):
         # assume it's a sympy expression
+        if isinstance(f, Equality):
+            f = f.lhs - f.rhs
         f = f.evalf()
-        atoms = f.atoms()
-        if not (len(atoms) == 1 and (args in atoms or args[0] in atoms)):
+        atoms = set(s for s in f.atoms() if isinstance(s, Symbol))
+        if not isinstance(fargs, (list, tuple)):  # a bare Symbol or None
+            fargs = (fargs,) if fargs is not None else tuple(atoms)
+        if not (len(atoms) == len(fargs) == 1 and fargs[0] in atoms):
             raise ValueError('expected a onedimensional and numerical function')
-        f = lambdify(args, f, modules)
+        f = lambdify(fargs, f, modules)
         return findroot(f, x0, **kwargs)
-    if len(args) != f.cols:
-        raise NotImplementedError('need exactly as many variables as equations')
+    if len(fargs) > f.cols:
+        raise NotImplementedError('need at least as many equations as variables')
     verbose = kwargs.get('verbose', False)
     if verbose:
         print 'f(x):'
         print f
     # derive Jacobian
-    J = f.jacobian(args)
+    J = f.jacobian(fargs)
     if verbose:
         print 'J(x):'
         print J
     # create functions
-    f = lambdify(args, f.T, modules)
-    J = lambdify(args, J, modules)
+    f = lambdify(fargs, f.T, modules)
+    J = lambdify(fargs, J, modules)
     # solve the system numerically
     x = findroot(f, x0, J=J, **kwargs)
     return x
diff --git a/sympy/solvers/tests/test_numeric.py b/sympy/solvers/tests/test_numeric.py
index 21a35f5..cfd4033 100644
--- a/sympy/solvers/tests/test_numeric.py
+++ b/sympy/solvers/tests/test_numeric.py
@@ -1,13 +1,14 @@
 from sympy.mpmath import mnorm_1
 from sympy.solvers import nsolve
 from sympy.utilities.lambdify import lambdify
-from sympy import Symbol, Matrix,  sqrt
+from sympy import Symbol, Matrix, sqrt, Eq
 
 def test_nsolve():
     # onedimensional
     from sympy import Symbol, sin, pi
     x = Symbol('x')
-    assert nsolve(x, sin(x), 2) - pi.evalf() < 1e-16
+    assert abs(nsolve(sin(x), 2) - pi.evalf()) < 1e-16
+    assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10)
     # multidimensional
     x1 = Symbol('x1')
     x2 = Symbol('x2')
@@ -16,7 +17,7 @@ def test_nsolve():
     f = Matrix((f1, f2)).T
     F = lambdify((x1, x2), f.T, modules='mpmath')
     for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]:
-        x = nsolve((x1, x2), f, x0, tol=1.e-8)
+        x = nsolve(f, (x1, x2), x0, tol=1.e-8)
         assert mnorm_1(F(*x)) <= 1.e-10
     # The Chinese mathematician Zhu Shijie was the very first to solve this
     # nonlinear system 700 years ago (z was added to make it 3-dimensional)
@@ -29,7 +30,7 @@ def test_nsolve():
     f = Matrix((f1, f2, f3)).T
     F = lambdify((x,  y,  z), f.T, modules='mpmath')
     def getroot(x0):
-        root = nsolve((x,  y,  z),  (f1,  f2,  f3),  x0)
+        root = nsolve((f1,  f2,  f3), (x,  y,  z), x0)
         assert mnorm_1(F(*root)) <= 1.e-8
         return root
     assert map(round,  getroot((1,  1,  1))) == [2.0,  1.0,  0.0]
-- 
1.6.0.2

Reply via email to