Hi,

On Wed, Aug 5, 2009 at 7:25 AM, Burcin Erocal<[email protected]> wrote:
> IIRC, you wrote that your implementation can coexist with the current
> one in Sage. Why don't you submit your changes so people can try out
> both approaches?


Here are the patches for Sage 4.1.1. I am seeking review from both
pynac and Sage volunteers. I am attaching the pynac patch
to this mail and CC'ing pynac-devel.

The necessary instructions and the required patch are posted on Trac:

http://trac.sagemath.org/sage_trac/ticket/6756


Cheers,
Golam

--~--~---------~--~----~------------~-------~--~----~
To post to this group, send an email to [email protected]
To unsubscribe from this group, send an email to 
[email protected]
For more options, visit this group at http://groups.google.com/group/sage-devel
URLs: http://www.sagemath.org
-~----------~----~----~----~------~----~------~--~---

# HG changeset patch
# User Golam Mortuza Hossain <[email protected]>
# Date 1250425297 10800
# Node ID 6ecd738aa8a4241ccf6dd1d212e1c2e69b350711
# Parent  2e94e1e945b4bbc017c91ff1030e1f7e8f87fb6d
Implement *diff* format symbolic derivative

diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/fderivative.cpp
--- a/ginac/fderivative.cpp	Sat Aug 01 04:09:34 2009 +0200
+++ b/ginac/fderivative.cpp	Sun Aug 16 09:21:37 2009 -0300
@@ -22,6 +22,12 @@
 
 #include <iostream>
 
+#include "ex.h"
+#include "numeric.h"
+#include "relational.h"
+#include "symbol.h"
+#include "add.h"
+#include "mul.h"
 #include "fderivative.h"
 #include "operators.h"
 #include "archive.h"
@@ -277,4 +283,305 @@
 	return res;
 }
 
+//////////////////
+// Implementation of *diff* format symbolic derivative
+// Golam Mortuza Hossain ([email protected]) 14-08-2009
+//////////////////
+
+/* Recursively append to *l* each symbol in *e* (and each function call, unless only_symbols) */
+void list_symbols_and_functions(const ex & e, lst & l, bool only_symbols)
+{
+    if (is_a<symbol>(e)) {
+        l.append(e);            // symbols are leaves: do not descend further
+        return;
+    }
+    if (!only_symbols && is_a<function>(e))
+        l.append(e);            // record the call itself, then scan its arguments
+    for (const_iterator it = e.begin(); it != e.end(); ++it)
+        list_symbols_and_functions(*it, l, only_symbols);   // duplicates are kept
+}
+
+/* Return true if *s* occurs in *e*, either as *e* itself or, recursively,
+   inside one of its operands. */
+bool has_symbol(const ex & e, const symbol & s)
+{
+    if (e.is_equal(s))
+        return true;
+    const_iterator it = e.begin(), last = e.end();
+    while (it != last && !has_symbol(*it, s))
+        ++it;
+    return it != last;  // stopped early: some operand contains s
+}
+
+/* Return true if *e* contains none of the entries of *l* (symbols or functions) */
+bool ex_is_independent(const ex & e, const lst & l)
+{
+    for (lst::const_iterator li = l.begin(); li != l.end(); ++li)
+        if (is_a<symbol>(*li) ? has_symbol(e, ex_to<symbol>(*li)) : e.has(*li))
+            return false;   // ex_to<symbol> is only valid for actual symbols
+    return true;
+}
+
+/* Differentiate the diff argument sequence *seq* = (f, x1, n1, x2, n2, ...)
+   a further *n* times with respect to *e* and return the new sequence.
+   Used by symbolic_eval_diff_seq() and, through fderivative.h, by
+   apply_diff_derivative() in function.cpp.
+
+   @param seq  f at index 0 followed by (variable, order) pairs; must be
+               non-empty and have its variables in canonical ex::compare()
+               order, which every sequence built by this function satisfies
+   @param e    variable to differentiate by (symbol or general expression)
+   @param n    differentiation order (numeric or symbolic)
+   @return     the updated sequence, built by the first applicable rule:
+
+     (1) a zero order *n* leaves *seq* unchanged;
+     (2) if *e* already is one of the x_i, its order n_i is increased by *n*;
+     (3) if *e* is a symbol that does not occur in f and *n* is a positive
+         integer, the derivative vanishes and the sequence (0) is returned;
+     (4) otherwise (e, n) is inserted so that the variables stay sorted,
+         e.g. (f, x, 1, z, 2) with e = y and x < y < z gives
+         (f, x, 1, y, n, z, 2).
+*/
+exvector get_new_diff_seq(const exvector & seq, const ex & e, const ex & n)
+{
+    // (1) Differentiating zero times is the identity; it must not turn
+    //     f into 0 just because f does not depend on e.
+    if (n.is_zero())
+        return seq;
+
+    const size_t num = seq.size();
+
+    // (2) Existing variable: accumulate its order
+    for (size_t i = 1; i < num; i += 2) {
+        if (seq[i].is_equal(e)) {
+            exvector nseq(seq);
+            nseq[i+1] = nseq[i+1] + n;
+            return nseq;
+        }
+    }
+
+    // (3) New symbol absent from f: the derivative is identically zero.
+    //     n != 0 here, so is_nonneg_integer() means "positive integer".
+    if (is_a<numeric>(n) && ex_to<numeric>(n).is_nonneg_integer()
+        && is_a<symbol>(e) && !has_symbol(seq[0], ex_to<symbol>(e)))
+        return exvector(1, _ex0);
+
+    // (4) Count the k existing variables that sort before e; inserting
+    //     (e, n) after the first k pairs equals inserting e into the
+    //     sorted variable list (lst::sort orders by ex::compare).
+    size_t k = 0;
+    for (size_t i = 1; i < num; i += 2)
+        if (seq[i].compare(e) < 0)
+            ++k;
+
+    const size_t split = 1 + 2*k;   // index of the first pair after e
+    exvector nseq;
+    nseq.reserve(num + 2);
+    nseq.insert(nseq.end(), seq.begin(), seq.begin() + split);
+    nseq.push_back(e);
+    nseq.push_back(n);
+    nseq.insert(nseq.end(), seq.begin() + split, seq.end());
+    return nseq;
+}
+
+/* Normalise a diff argument sequence (f, x1, n1, x2, n2, ...):
+     (1) if f is itself a diff, its sequence is merged with this one;
+     (2) pairs whose order is zero are skipped;
+     (3) new variables are inserted in sorted order by get_new_diff_seq,
+         which may also collapse the sequence to (0).
+   Assumes *orig_seq* has odd length (f followed by variable/order pairs). */
+exvector symbolic_eval_diff_seq(const exvector & orig_seq)
+{
+    const ex & head = orig_seq[0];
+    const bool head_is_diff = is_a<function>(head) &&
+        ex_to<function>(head).get_serial() == diff_derivative_SERIAL::serial;
+    // Seed with the inner diff's own arguments, or with f alone
+    exvector seq;
+    if (!head_is_diff)
+        seq.push_back(head);
+    else
+        for (const_iterator it = head.begin(); it != head.end(); ++it)
+            seq.push_back(*it);
+
+    for (size_t i = 1; i < orig_seq.size(); i += 2)
+        if (!orig_seq[i+1].is_zero())   // a 0th-order derivative is a no-op
+            seq = get_new_diff_seq(seq, orig_seq[i], orig_seq[i+1]);
+    return seq;
+}
+
+/* Return the held derivative of *result* w.r.t. the (variable, order) pairs in
+   left_over_seq[1..]. Sums are differentiated termwise and product factors that
+   contain no entry of *l* are pulled out front. left_over_seq is only read. */
+ex simplified_diff_expression(const ex & result, const lst & l, exvector & left_over_seq)
+{
+    if (is_a<add>(result)) {
+        ex sum = _ex0;
+        for (const_iterator ri = result.begin(); ri != result.end(); ++ri)
+            sum += simplified_diff_expression(*ri, l, left_over_seq);
+        return sum;
+    }
+    ex ifactor = _ex1;        // factors independent of every entry of l
+    ex dependent = result;    // what remains under the derivative sign
+    if (is_a<mul>(result)) {
+        dependent = _ex1;
+        for (const_iterator ri = result.begin(); ri != result.end(); ++ri)
+            if (ex_is_independent(*ri, l))
+                ifactor = (*ri)*ifactor;
+            else
+                dependent = (*ri)*dependent;
+    }
+    exvector seq(left_over_seq);  // copy: callers reuse theirs for several terms
+    seq[0] = dependent;
+    seq = symbolic_eval_diff_seq(seq);
+    if (seq.size() == 1)          // collapsed to (0), or no variables left
+        return ifactor*seq[0];
+    return ifactor*diff_derivative(seq).hold();
+}
+
+/*
+   Differentiating w.r.t. an expression often cannot be done by
+   substitution (x^2 cannot be substituted into x^4 to evaluate
+   diff(x^4, x^2, 1)), so we use, strictly valid when z = z(x),
+        df/dz = (df/dx)/(dz/dx)
+   If z = z(x,y) we check that the same answer is obtained w.r.t. *y*.
+   If not, f(x,y) != f(z(x,y)) and the derivative cannot be computed
+   without knowing which variable combination is kept fixed; in that
+   case basic_diff_wrt_ex() fails and the caller keeps it un-evaluated.
+*/
+unsigned int diff_recursion = 0;   // > 0 while a function variable is substituted
+
+/* Store in *result* the n-th derivative of *f* w.r.t. *e*, using the entries of *l*
+   (the symbols/functions of *e*) as variables; false if no well-defined answer. */
+bool basic_diff_wrt_ex(const ex & f, const ex & e, unsigned int n, const lst & l, ex & result)
+{
+    struct recursion_guard {   // keeps diff_recursion balanced even on throw
+        recursion_guard() { ++diff_recursion; }
+        ~recursion_guard() { --diff_recursion; }
+    };
+    ex lresult, nresult, denom;
+    symbol ns;                 // stand-in symbol for a function variable
+    result = f;
+    for (; n > 0; --n) {
+        bool first = true;
+        for (lst::const_iterator li = l.begin(); li != l.end(); ++li) {
+            if (is_a<symbol>(*li)) {
+                const symbol & s = ex_to<symbol>(*li);
+                nresult = result.diff(s);
+                denom   = e.diff(s);
+            } else {
+                // A symbolic diff here would make the last .subs() recurse forever
+                if (diff_recursion > 0)
+                    return false;
+                recursion_guard guard;
+                nresult = result.subs(*li==ns).diff(ns).subs(ns==*li);
+                denom   = e.subs(*li==ns).diff(ns).subs(ns==*li);
+            }
+            if (denom.is_zero())   // dz/dx == 0: the quotient is undefined
+                return false;
+            nresult = nresult/denom;
+            if (first)
+                lresult = nresult;
+            else if (!nresult.is_equal(lresult))
+                return false;      // the answer depends on the variable used
+            first = false;
+        }
+        result = lresult;
+        if (result.is_zero())
+            break;
+    }
+    return true;
+}
+
+/* Differentiate *f* n times w.r.t. the expression *e*; parts that basic_diff_wrt_ex()
+   cannot handle (all of f, or single terms of a sum) become held diff(...) objects. */
+ex eval_diff_wrt_ex(const ex & f, const ex & e, unsigned int n)
+{
+    exvector seq;
+    seq.push_back(f);
+    seq.push_back(e);
+    seq.push_back(n);
+
+    // The symbols/functions e is built from; none means e is a constant
+    lst l;
+    list_symbols_and_functions(e, l, false);
+    l.sort(); l.unique();   // unique() only drops *adjacent* duplicates
+    if (l.nops() < 1)
+        return diff_derivative(symbolic_eval_diff_seq(seq)).hold();
+    ex result;
+    if (basic_diff_wrt_ex(f, e, n, l, result))
+        return result;
+    if (is_a<add>(f)) {
+        ex nresult = _ex0;
+        for (const_iterator fi = f.begin(); fi != f.end(); ++fi) {
+            exvector term_seq(seq);   // fresh copy per term: the callee may modify it
+            if (basic_diff_wrt_ex(*fi, e, n, l, result))
+                nresult += result;
+            else
+                nresult += simplified_diff_expression(*fi, l, term_seq);
+        }
+        return nresult;
+    }
+    return simplified_diff_expression(f, l, seq);
+}
+
+/* Set while diff_derivative_eval() runs; apply_diff_derivative() then skips
+   the chain rule, as it has either been applied already or is not needed. */
+bool no_diff_chain = false;
+
+/* Main evaluation function for symbolic diff(f, x1, n1, x2, n2, ...).
+   Variables with a non-negative integer order are differentiated right away
+   (ex::diff for symbols, eval_diff_wrt_ex otherwise); variables with any
+   other order are kept in a held diff(...) of the partial result.
+   Throws std::runtime_error unless there is an odd number (>= 3) of arguments. */
+static ex diff_derivative_eval(const exvector & orig_seq)
+{
+    // Sets no_diff_chain for the rest of this call and restores the previous
+    // value on exit, so nested evaluations and exceptions cannot leave it stale.
+    struct chain_rule_blocker {
+        bool saved;
+        chain_rule_blocker() : saved(no_diff_chain) { no_diff_chain = true; }
+        ~chain_rule_blocker() { no_diff_chain = saved; }
+    };
+
+    size_t num = orig_seq.size();
+    if (num < 3 || num % 2 != 1)
+        throw(std::runtime_error("diff::eval(): malformed arguments. It should be of the form (f(x), x, n)"));
+
+    exvector seq = symbolic_eval_diff_seq(orig_seq);
+    if (seq[0].is_zero())
+        return _ex0;
+
+    num = seq.size();
+    ex result = seq[0];
+    exvector left_over_seq(1, result);  // (f, x_i, n_i, ...) for symbolic orders
+    lst l;                              // symbols of those x_i
+    chain_rule_blocker blocker;
+    for (size_t i = 1; i < num; i += 2) {
+        const ex & var = seq[i];
+        const ex & order = seq[i+1];
+        if (!is_a<numeric>(order) || !ex_to<numeric>(order).is_nonneg_integer()) {
+            left_over_seq.push_back(var);
+            left_over_seq.push_back(order);
+            list_symbols_and_functions(var, l, true);  // only symbols
+        } else {
+            const unsigned int n = ex_to<numeric>(order).to_int();
+            result = is_a<symbol>(var) ? result.diff(ex_to<symbol>(var), n)
+                                       : eval_diff_wrt_ex(result, var, n);
+        }
+        if (result.is_zero())
+            break;
+    }
+    // Wrap what could not be evaluated (symbolic orders) in a held diff
+    if (!result.is_zero() && left_over_seq.size() > 1) {
+        l.sort(); l.unique();
+        result = simplified_diff_expression(result, l, left_over_seq);
+    }
+    return result;
+}
+
+/* Register the symbolic "diff" function, evaluated by diff_derivative_eval(),
+   and keep its serial number in diff_derivative_SERIAL::serial. */
+unsigned diff_derivative_SERIAL::serial =
+    function::register_new(function_options("diff", 0).eval_func(diff_derivative_eval));
+
 } // namespace GiNaC
diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/fderivative.h
--- a/ginac/fderivative.h	Sat Aug 01 04:09:34 2009 +0200
+++ b/ginac/fderivative.h	Sun Aug 16 09:21:37 2009 -0300
@@ -85,6 +85,16 @@
 	paramset parameter_set; /**< Set of parameter numbers with respect to which to take the derivative */
 };
 
+/** Holds the serial number of the registered symbolic "diff" function */
+class diff_derivative_SERIAL { public: static unsigned serial; };
+/** Build the function call diff(p1...) from an exvector-convertible argument */
+template<typename T1> inline function diff_derivative(const T1 & p1)
+{ return function(diff_derivative_SERIAL::serial, exvector(p1)); }
+
+/** Differentiate a diff argument sequence (f, x1, n1, ...) a further n times w.r.t. e */
+exvector get_new_diff_seq(const exvector & seq, const ex & e, const ex & n);
+extern bool no_diff_chain;   // true while a symbolic diff is being evaluated
+
 } // namespace GiNaC
 
 #endif // ndef __GINAC_DERIVATIVE_H__
diff -r 2e94e1e945b4 -r 6ecd738aa8a4 ginac/function.cpp
--- a/ginac/function.cpp	Sat Aug 01 04:09:34 2009 +0200
+++ b/ginac/function.cpp	Sun Aug 16 09:21:37 2009 -0300
@@ -57,6 +57,7 @@
 	unsigned py_get_serial_from_sfunction(PyObject* f);
 	std::string* py_dumps(PyObject* o);
 	PyObject* py_loads(PyObject* o);
+	int diff_derivative_level;
 }
 namespace GiNaC {
 
@@ -1210,7 +1211,8 @@
 	GINAC_ASSERT(serial<registered_functions().size());
 	// Dynamically dispatch on print_context type
 	const print_context_class_info *pc_info = &c.get_class_info();
-	if (serial >= py_get_ginac_serial()) {
+	if (serial >= py_get_ginac_serial() ||
+        serial == diff_derivative_SERIAL::serial) {
 		//convert arguments to a PyTuple of Expressions
 		PyObject* args = exvector_to_PyTuple(seq);
 
@@ -1916,6 +1918,37 @@
 	throw(std::logic_error("function::imag_part(): invalid nparams"));
 }
 
+/* Apply the diff format symbolic derivative d/ds to the call serial(seq...),
+   or to the held diff(seq...) itself when *serial* is the diff function.
+   If diff_derivative_level > 1 and no_diff_chain is unset, the chain rule is
+   applied over the operands of f; otherwise s is added as a diff variable. */
+ex apply_diff_derivative(unsigned serial, const exvector & seq, const symbol & s)
+{
+    const bool is_diff = (serial == diff_derivative_SERIAL::serial);
+    // f: the expression being differentiated; nseq: the diff sequence to extend
+    const ex f = is_diff ? seq[0] : ex(function(serial, seq).hold());
+    exvector nseq;
+    if (is_diff)
+        nseq = seq;
+    else
+        nseq.push_back(f);
+
+    if (no_diff_chain || diff_derivative_level <= 1) {
+        nseq = get_new_diff_seq(nseq, s, _ex1);
+        if (nseq[0].is_zero())
+            return _ex0;
+        return diff_derivative(nseq).hold();
+    }
+    // Chain rule: sum over the operands a of f of diff(..., a, 1) * da/ds
+    ex result = _ex0;
+    for (const_iterator fi = f.begin(); fi != f.end(); ++fi) {
+        const ex arg_diff = (*fi).diff(s);
+        if (!arg_diff.is_zero())
+            result += diff_derivative(get_new_diff_seq(nseq, *fi, _ex1)).hold()*arg_diff;
+    }
+    return result;
+}
+
 // protected
 
 /** Implementation of ex::diff() for functions. It applies the chain rule,
@@ -1933,6 +1966,10 @@
 	GINAC_ASSERT(serial<registered_functions().size());
 	const function_options &opt = registered_functions()[serial];
 
+    // If requested then return diff format symbolic derivative
+    if (diff_derivative_level > 0 && opt.derivative_f == NULL)
+        return apply_diff_derivative(serial, seq, s);
+
 	// Check if we need to apply chain rule
 	if (!(opt.apply_chain_rule)) {
 		if (opt.derivative_f == NULL)

Reply via email to