I am developing an improved API for differentiation in SymPy: a class for
scalar differential operators (linear combinations of SymPy expressions and
partial derivatives).  The code (test_dop.py attached) produces the
following output:
[image: test_dop-1.jpg]
Running this code also requires the gprinter.py and dop.py modules (also
attached).  Note that the scalar differential operator API could be
extended to vector operators or differential forms; that only depends on how
you define the coefficients of the partial derivative operators.  You could
also define partial derivative operators that act on the left argument of *
instead of on the right argument.

-- 
You received this message because you are subscribed to the Google Groups 
"sympy" group.
To unsubscribe from this group and stop receiving emails from it, send an email 
to sympy+unsubscr...@googlegroups.com.
To view this discussion on the web visit 
https://groups.google.com/d/msgid/sympy/d94763b0-0ff2-4dfd-8947-2825211655a5n%40googlegroups.com.
#gprinter.py

import sys,shutil,os

from sympy import init_printing, latex

try:
    from IPython.display import display, Latex, Math, display_latex
except ImportError:
    pass
try:
    from sympy.interactive import printing
except ImportError:
    pass

from sympy import *

def isinteractive():  #Is ipython running
    '''
    Report whether the code is running under IPython.

    IPython defines the name ``__IPYTHON__``; plain python does not.  When it
    is present we assume a Jupyter Notebook/Lab front end is in use.

    Returns True under IPython, False otherwise.
    '''
    try:
        __IPYTHON__
    except NameError:
        return False
    else:
        return True

class gprinter:

    '''
    The gprinter class implements the functions needed to print sympy
    objects in a function, gprint, that mimics the syntax
    of the python 3 print function.  Additionally, it also implements latex
    printing from Jupyter Notebook/Lab and from python scripts.  For latex
    printing from linux and windows the following programs are required:

    Jupyter Notebook/Lab is sufficient for printing latex in both linux and
    windows since in both cases it now includes Mathjax.

    To print latex from a python script in linux a tex distribution is
    required.  It is suggested that the texlive-full distribution which
    includes pdfcrop be installed with the command:

        sudo apt-get install texlive-full

    To print latex from a python script in windows 10 the full texlive
    installation is also suggested. Use the following link to download
    the windows texlive installation program:

        http://mirror.ctan.org/systems/texlive/tlnet/install-tl-windows.exe

    In order to use the pdfcrop program in texlive a perl distribution must
    be installed.  Here is a suggested installation link:

        http://strawberryperl.com/
    '''

    line_sep = \
    """
************************************************************************
    """
    latex_preamble = \
    """
\\pagestyle{empty}
\\usepackage[latin1]{inputenc}
\\usepackage{amsmath}
\\usepackage{amsfonts}
\\usepackage{amssymb}
\\usepackage{amsbsy}
\\usepackage{tensor}
\\usepackage{listings}
\\usepackage{color}
\\usepackage{xcolor}
\\usepackage{bm}
\\usepackage{breqn}
\\definecolor{gray}{rgb}{0.95,0.95,0.95}
\\setlength{\\parindent}{0pt}
\\DeclareMathOperator{\\Tr}{Tr}
\\DeclareMathOperator{\\Adj}{Adj}
\\newcommand{\\bfrac}[2]{\\displaystyle\\frac{#1}{#2}}
\\newcommand{\\bc}{\\begin{center}}
\\newcommand{\\ec}{\\end{center}}
\\newcommand{\\lp}{\\left (}
\\newcommand{\\rp}{\\right )}
\\newcommand{\\paren}[1]{\\lp {#1} \\rp}
\\newcommand{\\half}{\\frac{1}{2}}
\\newcommand{\\llt}{\\left <}
\\newcommand{\\rgt}{\\right >}
\\newcommand{\\abs}[1]{\\left |{#1}\\right | }
\\newcommand{\\pdiff}[2]{\\bfrac{\\partial {#1}}{\\partial {#2}}}
\\newcommand{\\lbrc}{\\left \\{}
\\newcommand{\\rbrc}{\\right \\}}
\\newcommand{\\W}{\\wedge}
\\newcommand{\\prm}[1]{{#1}'}
\\newcommand{\\ddt}[1]{\\bfrac{d{#1}}{dt}}
\\newcommand{\\R}{\\dagger}
\\newcommand{\\deriv}[3]{\\bfrac{d^{#3}#1}{d{#2}^{#3}}}
\\newcommand{\\grade}[2]{\\left < {#1} \\right >_{#2}}
\\newcommand{\\f}[2]{{#1}\\lp{#2}\\rp}
\\newcommand{\\eval}[2]{\\left . {#1} \\right |_{#2}}
\\newcommand{\\Nabla}{\\boldsymbol{\\nabla}}
\\newcommand{\\eb}{\\boldsymbol{e}}
\\newcommand{\\bs}[1]{\\boldsymbol{#1}}
\\newcommand{\\grad}{\\bs{\\nabla}}
\\usepackage{float}
\\floatstyle{plain}
\\newfloat{Code}{H}{myc}
\\lstloadlanguages{Python}
    """

    ip_cmds = \
    [r'$$\DeclareMathOperator{\Tr}{Tr}$$',\
     r'$$\DeclareMathOperator{\Adj}{Adj}$$',\
     r'$$\DeclareMathOperator{\sinc}{sinc}$$',\
     r'$$\newcommand{\bfrac}[2]{\displaystyle\frac{#1}{#2}}$$',\
     r'$$\newcommand{\lp}{\left (}$$',\
     r'$$\newcommand{\rp}{\right )}$$',\
     r'$$\newcommand{\paren}[1]{\lp {#1} \rp}$$',\
     r'$$\newcommand{\half}{\frac{1}{2}}$$',\
     r'$$\newcommand{\llt}{\left <}$$',\
     r'$$\newcommand{\rgt}{\right >}$$',\
     r'$$\newcommand{\abs}[1]{\left |{#1}\right | }$$',\
     r'$$\newcommand{\pdiff}[2]{\bfrac{\partial {#1}}{\partial {#2}}}$$',\
     r'$$\newcommand{\npdiff}[3]{\bfrac{\partial^{#3} {#1}}{\partial {#2}^{#3}}}$$',\
     r'$$\newcommand{\lbrc}{\left \{}$$',\
     r'$$\newcommand{\rbrc}{\right \}}$$',\
     r'$$\newcommand{\W}{\wedge}$$',\
     r'$$\newcommand{\prm}[1]{{#1}}$$',\
     r'$$\newcommand{\ddt}[1]{\bfrac{d{#1}}{dt}}$$',\
     r'$$\newcommand{\R}{\dagger}$$',\
     r'$$\newcommand{\deriv}[3]{\bfrac{d^{#3}#1}{d{#2}^{#3}}}$$',\
     r'$$\newcommand{\grade}[2]{\left < {#1} \right >_{#2}}$$',\
     r'$$\newcommand{\f}[2]{{#1}\lp {#2} \rp}$$',\
     r'$$\newcommand{\eval}[2]{\left . {#1} \right |_{#2}}$$',\
     r'$$\newcommand{\bs}[1]{\boldsymbol{#1}}$$',\
     r'$$\newcommand{\grad}{\bs{\nabla}}$$']

#***********************************************************************

    # Per-platform shell commands used by pdf(): file removal, pdf viewer,
    # stdout suppression and background launch.
    SYS_CMD = {'linux2': {'rm': 'rm', 'evince': 'evince', 'null': ' > /dev/null', '&': '&'},
               'linux': {'rm': 'rm', 'evince': 'evince', 'null': ' > /dev/null', '&': '&'},
               'win32': {'rm': 'del', 'evince': 'start', 'null': ' > NUL', '&': ''},
               'darwin': {'rm': 'rm', 'evince': 'open', 'null': ' > /dev/null', '&': '&'}}

    latex_flg = False    # True once format() has been called
    latex_str = ''       # latex body accumulated by gprint() in scripts
    Format_cnt = 0
    pdiff_format = True  # read by dop.Pdop._latex to pick partial-derivative notation

    @classmethod
    def format(cls,pdiff_format=True):
        '''
        Switch gprint() to latex output.

        pdiff_format selects the partial-derivative notation used by
        dop.Pdop (fraction form when True, subscript form when False).
        Inside IPython, sympy mathjax printing is enabled and the macros in
        ip_cmds are sent to the notebook.
        '''
        cls.pdiff_format = pdiff_format

        if cls.Format_cnt == 0:
            cls.Format_cnt += 1

        cls.latex_flg = True  # Latex printing for scripts or Notebooks

        if isinteractive():  # Latex printing for Notebooks
            init_printing(use_latex='mathjax')
            from IPython.display import Math, display
            cmds = '\n'.join(cls.ip_cmds)
            display(Math(cmds))
        return

    @classmethod
    def gprint(cls,*xargs):
        r'''
        Print strings and sympy objects, mimicking the python print function.

        Strings are emitted verbatim.  A '\\' inside a string that comes after
        the first non-string argument closes the current equation and opens a
        new one.  Python types are printed as text.  Every other object is
        rendered with latex() when latex printing is enabled (see format()),
        otherwise with str(), and is surrounded by single spaces.

        In text mode the result is printed.  In a notebook it is displayed
        with MathJax.  In a script it is appended to cls.latex_str inside an
        equation* environment for later use by pdf().  A result starting with
        '#' is appended without the '#' and without the environment.
        '''
        parts = []
        nobj = 0  # number of non-string, non-type arguments seen so far
        for xi in xargs:
            if isinstance(xi,str):
                if r'\\' in xi and nobj > 0:
                    if isinteractive():  # Required for Jupyter Notebook/Lab
                        # '@' marks an equation break; split apart again below.
                        parts.append(xi.replace(r'\\',r'\end{equation*}@\begin{equation*} '))
                    else:  # Required for latex output from python scripts
                        parts.append(xi.replace(r'\\',r'\end{equation*}'+'\n'+r'\begin{equation*} '))
                else:  # Pure text printing
                    parts.append(xi)
            elif isinstance(xi,type):  # Special case for printing python type of an object
                if cls.latex_flg:
                    parts.append(r' \text{'+str(xi)+'} ')
                else:
                    parts.append(str(xi))
            else:
                text = latex(xi) if cls.latex_flg else str(xi)
                parts.append(' ' + text + ' ')
                nobj += 1

        # Joining, rather than '%'-formatting user text, keeps a literal '%'
        # in a label from breaking the output.
        out = ''.join(parts)

        if not cls.latex_flg:  # Print all xargs in text mode
            print(out)
        elif isinteractive():  # Print all xargs to Jupyter Notebook/Lab
            if '@' in out:
                lines = out.split('@')
                lines[0] = r'\begin{equation*} '+lines[0]
                lines[-1] += r'\end{equation*}'
                for line in lines:
                    display(Math(line))
            else:
                display(Math(out))
        elif out.startswith('#'):  # Raw latex, no equation environment
            cls.latex_str += out[1:]
        else:  # Add latex representation of all xargs to total latex string
            cls.latex_str +=  r'\begin{equation*} ' + out + r'\end{equation*} '+'\n'

        return

    @staticmethod
    def _remove_if_exists(path):
        '''Delete the file at path; do nothing if it does not exist.'''
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    @classmethod
    def pdf(cls,filename=None, paper=(14, 11), crop=None, prog=False, debug=False, pt='10pt', pdfprog='pdflatex'):

        '''
        Post processes LaTeX output, adds preamble and postamble, generates
        tex file, inputs file to latex, displays resulting pdf file.

        Arg         Value           Result
        filename    None            Tex file to write.  Default: the running script's
                                    name with '.py' replaced by '.tex'.
        paper       (14, 11)        Paper size in inches, or 'letter'/'landscape'
                                    (see paper_format).
        pt          '10pt'          Font size passed to the document class.
        pdfprog     'pdflatex'      Use pdfprog to generate pdf output, only generates tex file if pdfprog is None
        crop        margin in bp    Use "pdfcrop" to crop output file (pdfcrop must be installed)
                                    bp is the TeX big point approximately 1/72 inch.
        debug       False           If True, show the LaTeX program output and keep the tex file.
        prog        False           Currently unused.

        We assume that if gprinter.pdf() is called then gprinter.format() has been called at the beginning of the program.
        '''
        if not cls.latex_flg:
            print('gprinter.Format() has not been called. No LaTeX string to process')
            return

        latex_str = cls.paper_format(paper,pt)+cls.latex_preamble+r'\begin{document}'+'\n'+cls.latex_str+r'\end{document}'

        if filename is None:
            rootfilename = sys.argv[0].replace('.py', '')
            tex_filename = rootfilename + '.tex'
        else:
            tex_filename = filename
            rootfilename = os.path.splitext(tex_filename)[0]
        pdf_filename = rootfilename + '.pdf'

        if debug:
            print('latex file =', tex_filename)

        with open(tex_filename, 'w') as latex_file:
            latex_file.write(latex_str)

        if pdfprog is None:
            return

        # Unknown platforms (e.g. cygwin, *bsd) fall back to the POSIX commands.
        sys_cmd = cls.SYS_CMD.get(sys.platform, cls.SYS_CMD['linux'])

        # LaTeX writes .aux/.log/.pdf to the current directory unless told
        # otherwise.  Put them next to the tex file so that every path below
        # (all derived from rootfilename) is correct.
        outdir = os.path.dirname(tex_filename)
        latex_cmd = pdfprog
        if outdir:
            latex_cmd += ' -output-directory=' + outdir
        latex_cmd += ' ' + tex_filename

        if debug:  # Display latex execution output for debugging purposes
            print('pdflatex path =', shutil.which(pdfprog))
            os.system(latex_cmd)
        else:
            os.system(latex_cmd + sys_cmd['null'])

        for ext in ('.aux', '.log'):
            cls._remove_if_exists(rootfilename + ext)
        if not debug:
            cls._remove_if_exists(tex_filename)

        if not os.path.exists(pdf_filename):
            print('No pdf file ' + pdf_filename + ' was produced by ' + pdfprog)
            return

        if crop is not None:
            crop_cmd = 'pdfcrop --margins "'+str(crop)+'" '+pdf_filename
            os.system(crop_cmd)
            cropped = rootfilename + '-crop.pdf'
            # Keep the uncropped pdf if pdfcrop failed.  os.replace also
            # overwrites an existing target on Windows.
            if os.path.exists(cropped):
                os.replace(cropped, pdf_filename)

        os.system(sys_cmd['evince']+' '+pdf_filename)

        return

    @classmethod
    def paper_format(cls,paper,pt):  #Set size of paper and font size
        '''
        Return the documentclass line (and geometry setup) for the tex file.

        paper is 'letter' (no geometry package), 'landscape' (11in x 8.5in) or
        a (width, height) pair in inches.  The text area is one inch smaller
        than the paper in each direction.  pt replaces the font-size option.
        '''

        if paper == 'letter':
            paper_size = \
    """
\\documentclass[@10pt@,fleqn]{book}
    """
        else:
            paper_size = \
    """
\\documentclass[@10pt@,fleqn]{book}
\\usepackage[vcentering]{geometry}
    """
            if paper == 'landscape':
                paper = [11, 8.5]
            paper_size += '\\geometry{papersize={' + str(paper[0]) + \
                          'in,' + str(paper[1]) + 'in},total={' + str(paper[0] - 1) + \
                          'in,' + str(paper[1] - 1) + 'in}}\n'

        paper_size = paper_size.replace('@10pt@', pt)

        return(paper_size)

def gprint(*args):  #Shortcut so you don't have to code gprinter.gprint
    """Module-level shortcut: forward every argument to gprinter.gprint."""
    return gprinter.gprint(*args)


"""
Differential operators, for all sympy expressions
"""
import copy
import numbers
import warnings
from typing import List, Tuple, Any, Iterable

from sympy import Symbol, S, Add, simplify, diff, Expr, Dummy, expand,\
                  latex, Basic

from gprinter import gprinter

def apply_function_list(f, x):
    """
    Apply a function, or a chain of functions, to x.

    If f is a list or tuple, its elements are applied left to right, each to
    the previous result.  Otherwise, f(x) is returned.
    """
    if not isinstance(f, (tuple, list)):
        return f(x)
    result = x
    for func in f:
        result = func(result)
    return result

class _BaseDop(Basic):
    """
    Common base of this module's differential-operator classes.

    Derives from sympy's Basic rather than Expr, which is used to avoid
    accidental promotion.  Sdop.Add uses isinstance(..., _BaseDop) to tell
    operators apart from plain coefficient values.
    """

#################### Partial Derivative Operator Class #################

def _basic_diff(f, x, n=1):
    """
    Differentiate f n times with respect to x.

    sympy expressions and plain numbers go through sympy's diff.  Any other
    object must provide _eval_derivative_n_times(x, n), as the operator
    classes of this module do.  Otherwise, ValueError is raised.
    """
    if isinstance(f, (Expr, Symbol, numbers.Number)):
        return diff(f, x, n)
    if hasattr(f, '_eval_derivative_n_times'):
        return f._eval_derivative_n_times(x, n)
    raise ValueError('In_basic_diff type(arg) = ' + str(type(f)) + ' not allowed.')


class Pdop(_BaseDop):
    r"""
    Partial derivative operator.

    The partial derivatives are of the form

    .. math::
        \partial_{i_{1}...i_{n}} =
            \frac{\partial^{i_{1}+...+i_{n}}}{\partial{x_{1}^{i_{1}}}...\partial{x_{n}^{i_{n}}}}.

    If :math:`i_{j} = 0` then the partial derivative does not contain the
    :math:`x_{j}` coordinate.

    Attributes
    ----------
    pdiffs : dict
        A dictionary where coordinates are keys and key value are the number of
        times one differentiates with respect to the key.  Entries with a zero
        count are dropped on construction, so equal operators have equal dicts.
    order : int
        Total number of differentiations.
        When this is zero (i.e. when :attr:`pdiffs` is ``{}``) then this object
        is the identity operator, and returns its operand unchanged.
    """

    pdiff_format = gprinter.pdiff_format
    x = Symbol('x',real=True)

    def sort_key(self, order=None):
        return (
            # lower order derivatives first
            self.order,
            # sorted by symbol after that, after expansion
            sorted([
                x.sort_key(order)
                for x, k in self.pdiffs.items()
                for i in range(k)
            ])
        )

    def __eq__(self, A):
        if isinstance(A, Pdop) and self.pdiffs == A.pdiffs:
            return True
        else:
            # The identity operator also compares equal to 1.
            if len(self.pdiffs) == 0 and A == S.One:
                return True
            return False

    def __init__(self, __arg):
        """
        The partial differential operator is a partial derivative with
        respect to a set of real symbols (variables).

        The argument is either a dict {symbol: count} or a single Symbol
        (first derivative).  ``{}`` gives the identity operator.
        """

        if __arg is None:
            warnings.warn(
                "`Pdop(None)` is deprecated, use `Pdop({})` instead",
                DeprecationWarning, stacklevel=2)
            __arg = {}

        if isinstance(__arg, dict):  # Pdop defined by dictionary
            # Copy so later changes to the caller's dict cannot alter this
            # operator.  Drop zero counts: {x: 0, y: 1} must equal {y: 1}
            # and must not print a spurious d/dx.
            self.pdiffs = {key: n for key, n in __arg.items() if n != 0}
        elif isinstance(__arg, Symbol):  # First order derivative with respect to symbol
            self.pdiffs = {__arg: 1}
        else:
            raise TypeError('A dictionary or symbol is required, got {!r}'.format(__arg))

        self.order = sum(self.pdiffs.values())

    def _eval_derivative_n_times(self, x, n) -> 'Pdop':  # pdiff(self)
        # d is partial derivative
        pdiffs = copy.copy(self.pdiffs)
        if x in pdiffs:
            pdiffs[x] += n
        else:
            pdiffs[x] = n
        return Pdop(pdiffs)

    def __call__(self, arg):
        """
        Calculate nth order partial derivative (order defined by
        self) of expression.  The identity operator returns arg unchanged.
        """
        for x, n in self.pdiffs.items():
            arg = _basic_diff(arg, x, n)
        return arg

    def __pow__(self,other):
        """
        Compose the operator with itself ``other`` times.

        Each count is multiplied by ``other``; ``other == 0`` gives the
        identity operator.

        Raises TypeError for non-integers and ValueError for negative powers.
        """
        if isinstance(other,int):
            if other < 0:
                raise ValueError('For power of Pdop a non-negative integer is required, got {!r}'.format(other))
            return Pdop({key: n * other for key, n in self.pdiffs.items()})
        raise TypeError('For power of Pdop an integer is required, got {!r}'.format(other))

    def __mul__(self, other):  # functional product of self and arg (self*arg)
        return self(other)

    def __rmul__(self, other):  # functional product of arg and self (arg*self)
        assert not isinstance(other, Pdop)
        return Sdop([(other, self)])

    def __add__(self, other):
        if isinstance(other,Pdop):
            pd1 = Sdop([1],[self])
            pd2 = Sdop([1],[other])
            return pd1+pd2
        elif isinstance(other,Sdop):
            pd1 = Sdop([1],[self])
            return pd1+other
        else:
            return Sdop([other],[Pdop({})])+Sdop([1],[self])

    def __radd__(self, other):
        return self+other

    def _sympystr(self,printer):
        if self.order == 0:
            return 'D{}'
        s = ''
        for x in self.pdiffs:
            n = self.pdiffs[x]
            s += 'D'
            if n > 1:
                s += '^' + printer._print(n)

            s += '{' + printer._print(x) + '}'
        return s

    def _latex(self,printer):
        if self.order == 0:
            return ''
        if gprinter.pdiff_format:
            s = r'\frac{\partial'
            if self.order > 1:
                s += '^{' + printer._print(self.order) + '}'
            s += '}{'
            keys = list(self.pdiffs.keys())
            keys.sort(key=lambda x: x.sort_key())
            for key in keys:
                i = self.pdiffs[key]
                s += r'\partial ' + printer._print(key)
                if i > 1:
                    s += '^{' + printer._print(i) + '}'
            s += '}'
        else:
            s = ''
            keys = list(self.pdiffs.keys())
            keys.sort(key=lambda x: x.sort_key())
            for key in keys:
                i = self.pdiffs[key]
                s+= r'\partial'
                if ( i > 1 ):
                    s += r'^{' + printer._print(i) + r'}'
                s += r'_{' + printer._print(key) + '}'
        return s



########################################################################

def _merge_terms(terms1, terms2):
    """
    Combine two already-consolidated term sequences into one list.

    Operators keep first-appearance order (those of terms1 first, then any
    new ones from terms2).  Coefficients of equal operators are summed,
    starting from S.Zero, and terms whose total is S.Zero are dropped.
    """
    ops = []
    totals = []
    for coef, pdiff in list(terms1) + list(terms2):
        # Pdop defines __eq__ only, so a linear scan is used, not a dict.
        if pdiff in ops:
            totals[ops.index(pdiff)] += coef
        else:
            ops.append(pdiff)
            totals.append(S.Zero + coef)

    # remove zeros
    return [(total, op) for total, op in zip(totals, ops) if total != S.Zero]


def _eval_derivative_n_times_terms(terms, x, n):
    """
    Differentiate a sequence of (coef, pdiff) terms n times with respect to x.

    Each pass applies the product rule to every term:
    d(c*D) -> (dc)*D + c*(dD), dropping any piece that compares equal to 0.
    The final terms are consolidated with _consolidate_terms.
    """
    current = terms
    for _ in range(n):
        expanded = []
        for coef, pdiff in current:
            dcoef = _basic_diff(coef, x)
            dpdiff = _basic_diff(pdiff, x)
            if dcoef != 0:
                expanded.append((dcoef, pdiff))
            if dpdiff != 0:
                expanded.append((coef, dpdiff))
        current = expanded
    return _consolidate_terms(current)

def _consolidate_terms(terms):
    """
    Remove zero coefs and consolidate coefs with repeated pdiffs.

    Terms with equal pdiffs are merged by summing their coefficients, in
    first-appearance order.  Terms whose coefficient is zero, either on input
    or after summing (e.g. x and -x), are removed.

    Returns a tuple of (coef, pdiff) pairs.
    """
    new_coefs = []
    new_pdiffs = []
    for coef, pd in terms:
        if coef != S.Zero:
            if pd in new_pdiffs:
                index = new_pdiffs.index(pd)
                new_coefs[index] += coef
            else:
                new_coefs.append(coef)
                new_pdiffs.append(pd)
    # Coefficients may cancel while summing, so filter zeros again at the end.
    return tuple(
        (coef, pd) for coef, pd in zip(new_coefs, new_pdiffs) if coef != S.Zero
    )


################ Scalar Partial Differential Operator Class ############

class Sdop(_BaseDop):
    """
    Scalar differential operator is of the form (Einstein summation)

    .. math:: D = c_{i}*D_{i}

    where the :math:`c_{i}`'s are scalar coefficient (they could be functions)
    and the :math:`D_{i}`'s are partial differential operators (:class:`Pdop`).

    Attributes
    ----------
    terms : tuple of tuple
        the structure :math:`((c_{1},D_{1}),(c_{2},D_{2}), ...)`
    """

    def TSimplify(self):
        """Return a copy with sympy's simplify applied to every coefficient."""
        return Sdop([
            (simplify(coef), pdiff) for coef, pdiff in self.terms
        ])

    @staticmethod
    def consolidate_coefs(sdop):
        """
        Remove zero coefs and consolidate coefs with repeated pdiffs.

        Accepts an Sdop (returns an Sdop) or a sequence of terms (returns a
        tuple of terms).
        """
        if isinstance(sdop, Sdop):
            return Sdop(_consolidate_terms(sdop.terms))
        else:
            return _consolidate_terms(sdop)

    def simplify(self, modes=simplify):
        """
        Return a copy with ``modes`` applied to every coefficient.

        ``modes`` is one function or a list/tuple of functions applied in order
        (see apply_function_list).  The default is sympy's simplify.
        """
        return Sdop([
            (apply_function_list(modes, coef), pdiff)
            for coef, pdiff in self.terms
        ])

    def _with_sorted_terms(self):
        new_terms = sorted(self.terms, key=lambda term: Pdop.sort_key(term[1]))
        return Sdop(new_terms)

    def _sympystr(self, print_obj):
        if len(self.terms) == 0:
            # An operator with no terms is the zero operator.
            return print_obj._print(S.Zero)

        self = self._with_sorted_terms()
        s = ''
        for coef, pdop in self.terms:
            coef_str = print_obj._print(coef)
            pd_str = print_obj._print(pdop)

            if coef == S.One:
                s += pd_str
            elif coef == S.NegativeOne:
                s += '-' + pd_str
            else:
                if isinstance(coef, Add):
                    s += '(' + coef_str + ')*' + pd_str
                else:
                    s += coef_str + '*' + pd_str
            s += ' + '

        s = s.replace('+ -', '- ')
        s = s[:-3]
        return s

    def _latex(self, printer):
        if len(self.terms) == 0:
            # An operator with no terms is the zero operator.
            return printer._print(S.Zero)

        self = self._with_sorted_terms()

        s = ''
        for coef, pdop in self.terms:
            coef_str = printer._print(coef)
            pd_str = printer._print(pdop)
            if coef == S.One:
                if pd_str == '':
                    s += '1'
                else:
                    s += pd_str
            elif coef == S.NegativeOne:
                if pd_str == '':
                    s += '-1'
                else:
                    s += '-' + pd_str
            else:
                if isinstance(coef, Add):
                    s += r'\left ( ' + coef_str + r'\right ) ' + pd_str
                else:
                    s += coef_str + ' ' + pd_str
            s += ' + '

        s = s.replace('+ -', '- ')
        return s[:-3]

    def __init_from_symbol(self, symbol: Symbol) -> None:
        self.terms = ((S.One, Pdop(symbol)),)

    def __init_from_coef_and_pdiffs(self, coefs: List[Any], pdiffs: List['Pdop']) -> None:
        if not isinstance(coefs, list) or not isinstance(pdiffs, list):
            raise TypeError("coefs and pdiffs must be lists")
        if len(coefs) != len(pdiffs):
            raise ValueError('In Sdop.__init__ coefficent list and Pdop list must be same length.')
        self.terms = tuple(zip(coefs, pdiffs))

    def __init_from_terms(self, terms: Iterable[Tuple[Any, 'Pdop']]) -> None:
        self.terms = tuple(terms)

    def __init__(self, *args):
        """
        Sdop(symbol)          -> first derivative with respect to symbol
        Sdop(terms)           -> from a list/tuple of (coef, Pdop) pairs
        Sdop(coefs, pdiffs)   -> from two equal-length lists
        """
        if len(args) == 1:
            if isinstance(args[0], Symbol):
                self.__init_from_symbol(*args)
            elif isinstance(args[0], (list, tuple)):
                self.__init_from_terms(*args)
            else:
                raise TypeError(
                    "A symbol or sequence is required (got type {})"
                    .format(type(args[0]).__name__))
        elif len(args) == 2:
            self.__init_from_coef_and_pdiffs(*args)
        else:
            raise TypeError(
                "Sdop() takes from 1 to 2 positional arguments but {} were "
                "given".format(len(args)))

    def __call__(self, arg):
        # Ensure that we return the right type even when there are no terms - we
        # do this by adding `0 * d(arg)/d(nonexistant)`, which must be zero, but
        # will be a zero of the right type.
        dummy_var = Dummy('nonexistant')
        terms = self.terms or ((S.Zero, Pdop(dummy_var)),)
        return sum([coef * pdiff(arg) for coef, pdiff in terms])

    def __neg__(self):
        return Sdop([(-coef, pdiff) for coef, pdiff in self.terms])

    @staticmethod
    def Add(sdop1, sdop2):
        """
        Sum of two operators, or of an operator and a value.

        A bare Pdop is promoted to a one-term Sdop.  A non-operator value is
        promoted to a multiplicative (identity-Pdop) term.
        """
        if isinstance(sdop1, Pdop):
            sdop1 = Sdop([(S.One, sdop1)])
        if isinstance(sdop2, Pdop):
            sdop2 = Sdop([(S.One, sdop2)])
        if isinstance(sdop1, Sdop) and isinstance(sdop2, Sdop):
            return Sdop(_merge_terms(sdop1.terms, sdop2.terms))
        else:
            # convert values to multiplicative operators
            if not isinstance(sdop2, _BaseDop):
                sdop2 = Sdop([(sdop2, Pdop({}))])
            elif not isinstance(sdop1, _BaseDop):
                sdop1 = Sdop([(sdop1, Pdop({}))])
            else:
                return NotImplemented
            return Sdop.Add(sdop1, sdop2)

    def __eq__(self, other):
        if isinstance(other, Sdop):
            diff = self - other
            return len(diff.terms) == 0
        else:
            return NotImplemented

    def __add__(self, other):
        if isinstance(other, Pdop):
            other = Sdop([1],[other])
        elif isinstance(other, Expr):
            other = Sdop([other],[Pdop({})])
        else:
            pass

        return Sdop.Add(self, other)

    def __pow__(self, other):
        """
        Compose the operator with itself ``other`` times (``other >= 0``).

        ``other == 0`` gives the identity operator.

        Raises TypeError for non-integers and ValueError for negative powers.
        """
        if isinstance(other,int):
            if other < 0:
                raise ValueError('For power of Sdop a non-negative integer is required, got {!r}'.format(other))
            if other == 0:
                return Sdop([(S.One, Pdop({}))])
            p = self
            for i in range(1,other):
                p = p*self
            return p
        raise TypeError('For power of Sdop an integer is required, got {!r}'.format(other))

    def __radd__(self, sdop):
        if isinstance(sdop,Pdop):
            sdop = Sdop([1],[sdop])
        return Sdop.Add(sdop, self)

    def __sub__(self, sdop):
        # Pdop has no unary minus, so promote it before negating.
        if isinstance(sdop, Pdop):
            sdop = Sdop([(S.One, sdop)])
        return Sdop.Add(self, -sdop)

    def __rsub__(self, sdop):
        return Sdop.Add(-self, sdop)

    def __mul__(self, sdopr):
        # alias for applying the operator
        return self.__call__(sdopr)

    def __rmul__(self, sdop):
        return Sdop([(sdop * coef, pdiff) for coef, pdiff in self.terms])

    def _eval_derivative_n_times(self, x, n):
        return Sdop(_eval_derivative_n_times_terms(self.terms, x, n))




from sympy import symbols
from gprinter import *
from dop import *

gprinter.format()  # switch gprint over to latex output

x, y, z = symbols('x y z', real=True)
r2 = x**2 + y**2 + z**2

# First-order partial derivative operators.
Dx = Pdop(x)
Dy = Pdop(y)
Dz = Pdop(z)

# Laplacian and r.grad built from the partial derivative operators.
nabla2 = Dx**2 + Dy**2 + Dz**2
rdotgrad = x*Dx + y*Dy + z*Dz

# One gprint call per (label, expression) pair, in display order.
for label, value in (
        ('r^2 = ', r2),
        (r'\nabla^2 = ', nabla2),
        (r'\nabla^2 r^2 = ', nabla2*r2),
        (r'r^2 \nabla^2 = ', r2*nabla2),
        (r'r^2 \nabla^2 r^2 = ', r2*nabla2*r2),
        (r'r^2+ \nabla^2 = ', r2+nabla2),
        (r'\bs{r}\cdot\bs{\nabla} = ', rdotgrad),
        (r'\bs{r}\cdot\bs{\nabla} r^2= ', rdotgrad*r2),
):
    gprint(label, value)

# Write the collected latex to a pdf, cropped with a 5bp margin.
gprinter.pdf(crop=5)

Reply via email to