Hi, On Wed, Aug 5, 2009 at 7:25 AM, Burcin Erocal<[email protected]> wrote: > IIRC, you wrote that your implementation can coexist with the current > one in Sage. Why don't you submit your changes so people can try out > both approaches?
Here are the patches for Sage-4.1.1. I am seeking review from both Pynac and Sage volunteers. I am attaching the Pynac patch to this mail and CC'ing pynac-devel. The necessary instructions and the needed patch are posted on the trac ticket: http://trac.sagemath.org/sage_trac/ticket/6756 Cheers, Golam --~--~---------~--~----~------------~-------~--~----~ To post to this group, send an email to [email protected] To unsubscribe from this group, send an email to [email protected] For more options, visit this group at http://groups.google.com/group/sage-devel URLs: http://www.sagemath.org -~----------~----~----~----~------~----~------~--~---
# HG changeset patch # User Golam Mortuza Hossain <[email protected]> # Date 1250425297 10800 # Node ID 6ecd738aa8a4241ccf6dd1d212e1c2e69b350711 # Parent 2e94e1e945b4bbc017c91ff1030e1f7e8f87fb6d Implement *diff* format symbolic derivative diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/fderivative.cpp --- a/ginac/fderivative.cpp Sat Aug 01 04:09:34 2009 +0200 +++ b/ginac/fderivative.cpp Sun Aug 16 09:21:37 2009 -0300 @@ -22,6 +22,12 @@ #include <iostream> +#include "ex.h" +#include "numeric.h" +#include "relational.h" +#include "symbol.h" +#include "add.h" +#include "mul.h" #include "fderivative.h" #include "operators.h" #include "archive.h" @@ -277,4 +283,305 @@ return res; } +////////////////// +// Implementation of *diff* format symbolic derivative +// Golam Mortuza Hossain ([email protected]) 14-08-2009 +////////////////// + +/* Append all symbols and conditionally functions of *e* in the list *l* */ +void list_symbols_and_functions(const ex & e, lst & l, bool only_symbols) +{ + if (is_a<symbol>(e)) + l.append(e); + else { + if (!only_symbols && is_a<function>(e)) + l.append(e); + for(const_iterator ei = e.begin(); ei != e.end(); ++ei) + list_symbols_and_functions(*ei, l, only_symbols); + } +} + +/* Check whether the expression contains the given symbol */ +bool has_symbol(const ex & e, const symbol & s) +{ + if (e.is_equal(s)) + return true; + else + for(const_iterator ei = e.begin(); ei != e.end(); ++ei) + if (has_symbol(*ei, s)) + return true; + return false; +} + +/* Check whether the expression has any symbol listed in *l* */ +bool ex_is_independent(const ex & e, const lst & l) +{ + for(lst::const_iterator li = l.begin(); li != l.end(); ++li) + if (has_symbol(e, ex_to<symbol>(*li))) + return false; + return true; +} + +/* Construct a new sorted exvector seq for diff. 
*/ +exvector get_new_diff_seq(const exvector & seq, const ex & e, const ex & n) +{ + exvector nseq; + size_t num = seq.size(); + + // First differentiation + if (num == 1) { + if (is_a<numeric>(n) && ex_to<numeric>(n).is_nonneg_integer() + && is_a<symbol>(e) && !has_symbol(seq[0], ex_to<symbol>(e))) { + nseq.push_back(_ex0); + return nseq; + } + nseq = seq; + nseq.push_back(e); + nseq.push_back(n); + return nseq; + } + + // Check for existing variable + lst l; + for (size_t i=1; i<num; i=i+2) { + l.append(seq[i]); + if (seq[i].is_equal(e)) { + nseq = seq; + nseq[i+1] = nseq[i+1] + n; + return nseq; + } + } + + // We got a new variable + if (is_a<numeric>(n) && ex_to<numeric>(n).is_nonneg_integer() + && is_a<symbol>(e) && !has_symbol(seq[0], ex_to<symbol>(e))) { + nseq.push_back(_ex0); + return nseq; + } + l.append(e); l.sort(); + bool got_it = false; + size_t i=0, j; + nseq.push_back(seq[0]); + + for (lst::const_iterator li = l.begin(); li != l.end(); ++li) { + if(got_it) { + j = 2*i - 1; + nseq.push_back(seq[j]); + nseq.push_back(seq[j+1]); + } else { + if(e.is_equal(*li)) { + nseq.push_back(e); + nseq.push_back(n); + got_it = true; + } else { + j = 2*i + 1; + nseq.push_back(seq[j]); + nseq.push_back(seq[j+1]); + } + } + i++; + } + return nseq; +} + +/* Symbolically evaluate given diff seq. 
+ It does: + (1) If it is diff of diff then merge two seq's + (2) remove diff var x if it acts 0 times + (3) Put the seq in preferred sorted order +*/ +exvector symbolic_eval_diff_seq(const exvector & orig_seq) +{ + ex f = orig_seq[0]; + exvector seq; + if (is_a<function>(f) && ex_to<function>(f).get_serial() + == diff_derivative_SERIAL::serial) + for (const_iterator fi = f.begin(); fi != f.end(); ++fi) + seq.push_back(*fi); + else + seq.push_back(f); + + size_t num = orig_seq.size(); + for (size_t i=1; i<num; i=i+2) + if(!orig_seq[i+1].is_zero()) + seq = get_new_diff_seq(seq, orig_seq[i], orig_seq[i+1]); + + return seq; +} + +/* Return a simplified un-evaluated symbolic derivative */ +ex simplified_diff_expression(const ex & result, const lst & l, exvector & left_over_seq) +{ + ex ifactor = _ex1; + ex nresult = _ex1; + + if(is_a<mul>(result)) { + for(const_iterator ri = result.begin(); ri != result.end(); ++ri) { + if(ex_is_independent(*ri, l)) + ifactor = (*ri)*ifactor; + else + nresult = (*ri)*nresult; + } + left_over_seq[0] = nresult; + left_over_seq = symbolic_eval_diff_seq(left_over_seq); + nresult = ifactor*diff_derivative(left_over_seq).hold(); + } else if(is_a<add>(result)) { + nresult = _ex0; + for(const_iterator ri = result.begin(); ri != result.end(); ++ri) + nresult += simplified_diff_expression(*ri, l, left_over_seq); + } else { + left_over_seq[0] = result; + left_over_seq = symbolic_eval_diff_seq(left_over_seq); + nresult = diff_derivative(left_over_seq).hold(); + } + + return nresult; +} + +/* + Derivative w.r.t. to an expression often will not work + using substitution method. For example, x^2 can't + be substituted in x^4 for evaluating diff(x^4, x^2, 1). + Here we use following relation (strictly valid when z=z(x)) + df/dz = (df/dx)/(dz/dx) + If z = z(x,y) then we verify whether we get the same answer + even w.r.t. *y*. 
If not then f(x,y) != f(z(x,y)) and this + derivative can't be computed without knowing which variable + combination is kept fixed. In such case, we return it un-evaluted. +*/ +unsigned int diff_recursion = 0; + +bool basic_diff_wrt_ex(const ex & f, const ex & e, unsigned int n, const lst & l, ex & result) +{ + ex lresult, nresult, denom; + symbol s, ns; + unsigned int i; + + result = f; + while (n>0) { + i = 0; + for (lst::const_iterator li = l.begin(); li != l.end(); ++li) { + if(is_a<symbol>(*li)) { + s = ex_to<symbol>(*li); + nresult = result.diff(s)/(e.diff(s)); + } else { + // If following steps results in a symbolic diff then the last + // .subs() will get us in a loop. If so break the loop. + if (diff_recursion > 0) + return false; + diff_recursion++; + nresult = result.subs(*li==ns).diff(ns).subs(ns==*li); + denom = e.subs(*li==ns).diff(ns).subs(ns==*li); + diff_recursion--; + //Check denominator, if zero then give up + if (denom.is_zero()) + return false; + nresult = nresult/denom; + } + if (i == 0) + lresult = nresult; + else if (!nresult.is_equal(lresult)) + return false; + i++; + } + result = lresult; + if(result.is_zero()) + break; + n--; + } + return true; +} + +/* Evaluate using *basic_diff_wrt_ex* with some checks*/ +ex eval_diff_wrt_ex(const ex & f, const ex & e, unsigned int n) +{ + exvector seq; + seq.push_back(f); + seq.push_back(e); + seq.push_back(n); + + lst l; + list_symbols_and_functions(e, l, false); + l.unique(); + if(l.nops() < 1) + return diff_derivative(symbolic_eval_diff_seq(seq)).hold(); + + ex result; + if(basic_diff_wrt_ex(f, e, n, l, result)) + return result; + + ex nresult; + if(is_a<add>(f)) { + nresult = _ex0; + for(const_iterator fi = f.begin(); fi != f.end(); ++fi) + if(basic_diff_wrt_ex(*fi, e, n, l, result)) + nresult += result; + else + nresult += simplified_diff_expression(*fi, l, seq); + return nresult ; + } + + return simplified_diff_expression(f, l, seq); +} + +/* flag to disable chain rule */ +bool no_diff_chain = 
false; + +/* Main evaluation function for symbolic diff */ +static ex diff_derivative_eval(const exvector & orig_seq) +{ + //Sanity checks for inputs + size_t num = orig_seq.size(); + if( num < 3 || num % 2 != 1) + throw(std::runtime_error("diff::eval(): malformed arguments. It should be of the form (f(x), x, n)")); + + exvector seq = symbolic_eval_diff_seq(orig_seq); + if(seq[0].is_zero()) + return _ex0; + + num = seq.size(); + ex result = seq[0]; + exvector left_over_seq; + lst l; + bool n_is_symbolic; + unsigned int n; + no_diff_chain = true; // Disable chain rule as either it has been + // applied or not needed + left_over_seq.push_back(result); + for (size_t i=1; i<num; i=i+2) { + if (!is_a<numeric>(seq[i+1]) || + !ex_to<numeric>(seq[i+1]).is_nonneg_integer()) + n_is_symbolic = true; + else { + n_is_symbolic = false; + n = ex_to<numeric>(seq[i+1]).to_int(); + } + + if(n_is_symbolic) { + left_over_seq.push_back(seq[i]); + left_over_seq.push_back(seq[i+1]); + list_symbols_and_functions(seq[i], l, true); //only symbols + } + else if(is_a<symbol>(seq[i])) + result = result.diff(ex_to<symbol>(seq[i]), n); + else + result = eval_diff_wrt_ex(result, seq[i], n); + if(result.is_zero()) + break; + } + // Check for left over + num = left_over_seq.size(); + if(!result.is_zero() && num > 1) { + l.unique(); + result = simplified_diff_expression(result, l, left_over_seq); + } + no_diff_chain = false; + return result; +} + +/* Register symbolic diff derivative function */ +unsigned diff_derivative_SERIAL::serial = function::register_new( + function_options("diff", 0). 
+ eval_func(diff_derivative_eval)); + } // namespace GiNaC diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/fderivative.h --- a/ginac/fderivative.h Sat Aug 01 04:09:34 2009 +0200 +++ b/ginac/fderivative.h Sun Aug 16 09:21:37 2009 -0300 @@ -85,6 +85,16 @@ paramset parameter_set; /**< Set of parameter numbers with respect to which to take the derivative */ }; +/** Diff format symbolic derivative */ +class diff_derivative_SERIAL { public: static unsigned serial; }; +template<typename T1> +inline function diff_derivative(const T1 & p1) { + return function(diff_derivative_SERIAL::serial, exvector(p1)); +} + +exvector get_new_diff_seq(const exvector & seq, const ex & e, const ex & n); +extern bool no_diff_chain; + } // namespace GiNaC #endif // ndef __GINAC_DERIVATIVE_H__ diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/function.cpp --- a/ginac/function.cpp Sat Aug 01 04:09:34 2009 +0200 +++ b/ginac/function.cpp Sun Aug 16 09:21:37 2009 -0300 @@ -57,6 +57,7 @@ unsigned py_get_serial_from_sfunction(PyObject* f); std::string* py_dumps(PyObject* o); PyObject* py_loads(PyObject* o); + int diff_derivative_level; } namespace GiNaC { @@ -1210,7 +1211,8 @@ GINAC_ASSERT(serial<registered_functions().size()); // Dynamically dispatch on print_context type const print_context_class_info *pc_info = &c.get_class_info(); - if (serial >= py_get_ginac_serial()) { + if (serial >= py_get_ginac_serial() || + serial == diff_derivative_SERIAL::serial) { //convert arguments to a PyTuple of Expressions PyObject* args = exvector_to_PyTuple(seq); @@ -1916,6 +1918,37 @@ throw(std::logic_error("function::imag_part(): invalid nparams")); } +/* Appy diff format symbolic derivative */ +ex apply_diff_derivative(unsigned serial, const exvector & seq, const symbol & s) +{ + ex f; + exvector nseq; + + if (serial == diff_derivative_SERIAL::serial) { + f = seq[0]; + nseq = seq; + } else { + f = function(serial, seq).hold(); + nseq.push_back(f); + } + + if (!no_diff_chain && diff_derivative_level > 1) { + ex arg_diff; + 
ex result = _ex0; + for (const_iterator fi = f.begin(); fi != f.end(); ++fi) { + arg_diff = (*fi).diff(s); + if (!arg_diff.is_zero()) + result += diff_derivative(get_new_diff_seq( + nseq, *fi, _ex1)).hold()*arg_diff; + } + return result; + } + nseq = get_new_diff_seq(nseq, s, _ex1); + if (nseq[0].is_zero()) + return _ex0; + return diff_derivative(nseq).hold(); +} + // protected /** Implementation of ex::diff() for functions. It applies the chain rule, @@ -1933,6 +1966,10 @@ GINAC_ASSERT(serial<registered_functions().size()); const function_options &opt = registered_functions()[serial]; + // If requested then return diff format symbolic derivative + if (diff_derivative_level > 0 && opt.derivative_f == NULL) + return apply_diff_derivative(serial, seq, s); + // Check if we need to apply chain rule if (!(opt.apply_chain_rule)) { if (opt.derivative_f == NULL)
