Great job! +1

Ondrej

On Sat, Jan 24, 2009 at 11:12 AM, Mateusz Paprocki <matt...@gmail.com> wrote:
>
> This is a preliminary implementation of rsolve() using rsolve_hyper()
> as the main recurrence solver. The syntax provided by rsolve() is in
> general similar to Mathematica's syntax. The main difference is the
> way initial conditions are specified, e.g.:
>
> In [1]: f = y(n+2) - y(n+1) - y(n)
>
> In [2]: rsolve(f, y(n), { y(0):0, y(1):5 })
> Out[2]:
>                   n                      n
>      ⎛        ⎽⎽⎽⎞          ⎛        ⎽⎽⎽⎞
>  ⎽⎽⎽ ⎜      ╲╱ 5 ⎟      ⎽⎽⎽ ⎜      ╲╱ 5 ⎟
> ╲╱ 5 ⋅⎜1/2 + ─────⎟  - ╲╱ 5 ⋅⎜1/2 - ─────⎟
>      ⎝        2  ⎠          ⎝        2  ⎠
>
> Currently supported are all kinds of recurrence relations that
> are supported by rsolve_hyper(), i.e. linear with polynomial or
> fractional coefficients (by removing denominators using LCM and
> exact division), homogeneous or inhomogeneous (hypergeometric).
>
> See the docstring for details on the exact syntax.
> ---
>  sympy/solvers/recurr.py            |  177 +++++++++++++++++++++++++++++++++++-
>  sympy/solvers/tests/test_recurr.py |   29 ++++++-
>  2 files changed, 199 insertions(+), 7 deletions(-)
>
> diff --git a/sympy/solvers/recurr.py b/sympy/solvers/recurr.py
> index cf9a339..f06d798 100644
> --- a/sympy/solvers/recurr.py
> +++ b/sympy/solvers/recurr.py
> @@ -5,18 +5,20 @@
>    The solutions are obtained among polynomials, rational functions,
>    hypergeometric terms, or combinations of hypergeometric term which
>    are pairwise dissimilar.
> +
>  """
>
>  from sympy.core.basic import Basic, S
>  from sympy.core.numbers import Rational
> -from sympy.core.symbol import Symbol
> +from sympy.core.symbol import Symbol, Wild
> +from sympy.core.relational import Equality
>  from sympy.core.add import Add
>  from sympy.core.mul import Mul
>  from sympy.core import sympify
>
>  from sympy.simplify import simplify, hypersimp, hypersimilar, collect
>  from sympy.solvers import solve, solve_undetermined_coeffs
> -from sympy.polys import Poly, quo, gcd, roots, resultant
> +from sympy.polys import Poly, quo, gcd, lcm, roots, resultant
>  from sympy.functions import Binomial, FallingFactorial
>  from sympy.matrices import Matrix, casoratian
>  from sympy.concrete import product
> @@ -554,8 +556,173 @@ def rsolve_hyper(coeffs, f, n, **hints):
>     else:
>         return result
>
> -def rsolve(eq, seq):
> -    """
> +def rsolve(f, y, init=None):
> +    """Solve univariate recurrence with rational coefficients.
> +
> +       Given k-th order linear recurrence Ly = f, or equivalently:
> +
> +         a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) + ... + a_{0}(n) y(n) = f
> +
> +       where a_{i}(n), for i=0..k, are polynomials or rational functions
> +       in n, and f is a hypergeometric function or a sum of a fixed number
> +       of pairwise dissimilar hypergeometric terms in n, finds all solutions
> +       or returns None, if none were found.
> +
> +       Initial conditions can be given as a dictionary in two forms:
> +
> +          [1] {   n_0  : v_0,   n_1  : v_1, ...,   n_m  : v_m }
> +          [2] { y(n_0) : v_0, y(n_1) : v_1, ..., y(n_m) : v_m }
> +
> +       or as a list L of values:
> +
> +          L = [ v_0, v_1, ..., v_m ]
> +
> +       where L[i] = v_i, for i=0..m, maps to y(n_i).
> +
> +       As an example, let's consider the following recurrence:
> +
> +         (n - 1) y(n + 2) - (n**2 + 3 n - 2) y(n + 1) + 2 n (n + 1) y(n) == 0
> +
> +       >>> from sympy import *
> +
> +       >>> y = Function('y')
> +       >>> n = Symbol('n', integer=True)
> +
> +       >>> f = (n-1)*y(n+2) - (n**2+3*n-2)*y(n+1) + 2*n*(n+1)*y(n)
> +
> +       >>> rsolve(f, y(n))
> +       C0*n! + C1*2**n
> +
> +       >>> rsolve(f, y(n), { y(0):0, y(1):3 })
> +       -3*n! + 3*2**n
>
>     """
> -    pass
> +    if isinstance(f, Equality):
> +        f = f.lhs - f.rhs
> +
> +    if f.is_Add:
> +        F = f.args
> +    else:
> +        F = [f]
> +
> +    k = Wild('k')
> +    n = y.args[0]
> +
> +    h_part = {}
> +    i_part = S.Zero
> +
> +    for g in F:
> +        if g.is_Mul:
> +            G = g.args
> +        else:
> +            G = [g]
> +
> +        coeff = S.One
> +        kspec = None
> +
> +        for h in G:
> +            if h.is_Function:
> +                if h.func == y.func:
> +                    result = h.args[0].match(n + k)
> +
> +                    if result is not None:
> +                        kspec = int(result[k])
> +                    else:
> +                        raise ValueError("'%s(%s+k)' expected, got '%s'" % (y.func, n, h))
> +                else:
> +                    raise ValueError("'%s' expected, got '%s'" % (y.func, h.func))
> +            else:
> +                coeff *= h
> +
> +        if kspec is not None:
> +            if h_part.has_key(kspec):
> +                h_part[kspec] += coeff
> +            else:
> +                h_part[kspec] = coeff
> +        else:
> +            i_part += coeff
> +
> +    for k, coeff in h_part.iteritems():
> +        h_part[k] = simplify(coeff)
> +
> +    common = S.One
> +
> +    for coeff in h_part.itervalues():
> +        if coeff.is_fraction(n):
> +            if not coeff.is_polynomial(n):
> +                common = lcm(common, coeff.as_numer_denom()[1], n)
> +        else:
> +            raise ValueError("Polynomial or rational function expected, got '%s'" % coeff)
> +
> +    i_numer, i_denom = i_part.as_numer_denom()
> +
> +    if i_denom.is_polynomial(n):
> +        common = lcm(common, i_denom, n)
> +
> +    if common is not S.One:
> +        for k, coeff in h_part.iteritems():
> +            numer, denom = coeff.as_numer_denom()
> +            h_part[k] = numer*quo(common, denom, n)
> +
> +        i_part = i_numer*quo(common, i_denom, n)
> +
> +    K_min = min(h_part.keys())
> +
> +    if K_min < 0:
> +        K = abs(K_min)
> +
> +        H_part = {}
> +        i_part = i_part.subs(n, n+K).expand()
> +        common = common.subs(n, n+K).expand()
> +
> +        for k, coeff in h_part.iteritems():
> +            H_part[k+K] = coeff.subs(n, n+K).expand()
> +    else:
> +        H_part = h_part
> +
> +    K_max = max(H_part.keys())
> +    coeffs = []
> +
> +    for i in xrange(0, K_max+1):
> +        if H_part.has_key(i):
> +            coeffs.append(H_part[i])
> +        else:
> +            coeffs.append(S.Zero)
> +
> +    result = rsolve_hyper(coeffs, i_part, n, symbols=True)
> +
> +    if result is None:
> +        return None
> +    else:
> +        solution, symbols = result
> +
> +        if symbols and init is not None:
> +            equations = []
> +
> +            if type(init) is list:
> +                for i in xrange(0, len(init)):
> +                    eq = solution.subs(n, i) - init[i]
> +                    equations.append(eq)
> +            else:
> +                for k, v in init.iteritems():
> +                    try:
> +                        i = int(k)
> +                    except TypeError:
> +                        if k.is_Function and k.func == y.func:
> +                            i = int(k.args[0])
> +                        else:
> +                            raise ValueError("Integer or term expected, got '%s'" % k)
> +
> +                    eq = solution.subs(n, i) - v
> +                    equations.append(eq)
> +
> +            result = solve(equations, *symbols)
> +
> +            if result is None:
> +                return None
> +            else:
> +                for k, v in result.iteritems():
> +                    solution = solution.subs(k, v)
> +
> +    return (solution.expand()) / common
> +
> diff --git a/sympy/solvers/tests/test_recurr.py b/sympy/solvers/tests/test_recurr.py
> index 1dfe9cc..620a349 100644
> --- a/sympy/solvers/tests/test_recurr.py
> +++ b/sympy/solvers/tests/test_recurr.py
> @@ -1,6 +1,7 @@
> -from sympy import symbols, rsolve_hyper, rsolve_poly, rsolve_ratio, S, sqrt, \
> -        rf, factorial
> +from sympy import Function, symbols, S, sqrt, rf, factorial
> +from sympy.solvers.recurr import rsolve, rsolve_poly, rsolve_ratio, rsolve_hyper
>
> +y = Function('y')
>  n, k = symbols('nk', integer=True)
>  C0, C1, C2 = symbols('C0', 'C1', 'C2')
>
> @@ -68,3 +69,27 @@ def test_rsolve_bulk():
>                 yield rsolve_bulk_checker, rsolve_poly, c, q, p
>             #if p.is_hypergeometric(n):
>             #    yield rsolve_bulk_checker, rsolve_hyper, c, q, p
> +
> +def test_rsolve():
> +    f = y(n+2) - y(n+1) - y(n)
> +    g = C0*(S.Half + S.Half*sqrt(5))**n \
> +      + C1*(S.Half - S.Half*sqrt(5))**n
> +    h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \
> +      - sqrt(5)*(S.Half - S.Half*sqrt(5))**n
> +
> +    assert rsolve(f, y(n)) == g
> +
> +    assert rsolve(f, y(n), [      0,      5 ]) == h
> +    assert rsolve(f, y(n), {   0 :0,   1 :5 }) == h
> +    assert rsolve(f, y(n), { y(0):0, y(1):5 }) == h
> +
> +    f = (n-1)*y(n+2) - (n**2+3*n-2)*y(n+1) + 2*n*(n+1)*y(n)
> +    g = C0*factorial(n) + C1*2**n
> +    h = -3*factorial(n) + 3*2**n
> +
> +    assert rsolve(f, y(n)) == g
> +
> +    assert rsolve(f, y(n), [      0,      3 ]) == h
> +    assert rsolve(f, y(n), {   0 :0,   1 :3 }) == h
> +    assert rsolve(f, y(n), { y(0):0, y(1):3 }) == h
> +
> --
> 1.6.1
>
>
> >
>

--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to