--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy" group.
To post to this group, send email to sympy@googlegroups.com
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at http://groups.google.com/group/sympy?hl=en
-~----------~----~----~----~------~----~------~--~---

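Fix for Issue 801: Add._eval_subs() now matches the terms of the old sum as a set (instead of only as a contiguous run of terms), also substitutes inside the remaining terms, and adds tests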

Before:

In [2]: (a+b+c).subs(a+b,c)
Out[2]: a + b + c

In [3]: (c+b+exp(c+b)).subs(c+b,a)
Out[3]:
     b + c
a + ℯ

In [4]: (a+b+exp(a+b)).subs(a+b,c)
Out[4]:
         c
a + b + ℯ

After:

In [2]: (a+b+c).subs(a+b,c)
Out[2]: 2⋅c

In [3]: (c+b+exp(c+b)).subs(c+b,a)
Out[3]:
     a
a + ℯ

In [4]: (a+b+exp(a+b)).subs(a+b,c)
Out[4]:
     c
c + ℯ

diff --git a/sympy/core/add.py b/sympy/core/add.py
--- a/sympy/core/add.py
+++ b/sympy/core/add.py
@@ -291,21 +291,20 @@
         return S.One,(self,)
 
     def _eval_subs(self, old, new):
-        if self==old: return new
+        if self == old: return new
         if isinstance(old, FunctionClass):
             return self.__class__(*[s.subs(old, new) for s in self.args ])
-        coeff1,factors1 = self.as_coeff_factors()
-        coeff2,factors2 = old.as_coeff_factors()
-        if factors1==factors2: # (2+a).subs(3+a,y) -> 2-3+y
-            return new + coeff1 - coeff2
+        coeff_self, factors_self = self.as_coeff_factors()
+        coeff_old, factors_old = old.as_coeff_factors()
+        if factors_self == factors_old: # (2+a).subs(3+a,y) -> 2-3+y
+            return Add(new, coeff_self, -coeff_old)
         if old.is_Add:
-            l1,l2 = len(factors1),len(factors2)
-            if l2<l1: # (a+b+c+d).subs(b+c,x) -> a+x+d
-                for i in xrange(l1-l2+1):
-                    if factors2==factors1[i:i+l2]:
-                        factors1 = list(factors1)
-                        factors2 = list(factors2)
-                        return Add(*([coeff1-coeff2]+factors1[:i]+[new]+factors1[i+l2:]))
+            if len(factors_old) < len(factors_self): # (a+b+c+d).subs(b+c,x) -> a+x+d
+                self_set = set(factors_self)
+                old_set = set(factors_old)
+                if old_set < self_set:
+                    ret_set = self_set - old_set
+                    return Add(new, coeff_self, -coeff_old, *[s.subs(old, new) for s in ret_set])
         return self.__class__(*[s.subs(old, new) for s in self.args])
 
     @cacheit
diff --git a/sympy/core/tests/test_subs.py b/sympy/core/tests/test_subs.py
--- a/sympy/core/tests/test_subs.py
+++ b/sympy/core/tests/test_subs.py
@@ -136,10 +136,14 @@
     assert (a**2 - c).subs(a**2 - c, d) == d
     assert (a**2 - b - c).subs(a**2 - c, d) in [d - b, a**2 - b - c]
     assert (a**2 - x - c).subs(a**2 - c, d) in [d - x, a**2 - x - c]
+    assert (a**2 - b - sqrt(a)).subs(a**2 - sqrt(a), c) == c - b
+    assert (a+b+exp(a+b)).subs(a+b,c) == c + exp(c)
+    assert (c+b+exp(c+b)).subs(c+b,a) == a + exp(a)
 
     # this should work everytime:
     e = a**2 - b - c
     assert e.subs(Add(*e.args[:2]), d) == d + e.args[2]
+    assert e.subs(a**2 - c, d) == d - b
 
 def test_subs_issue910():
     assert (I*Symbol("a")).subs(1, 2) == I*Symbol("a")

Reply via email to