Author: Lukas Diekmann <[email protected]>
Branch: set-strategies
Changeset: r49251:ddc6e9d447f3
Date: 2011-10-18 12:12 +0200
http://bitbucket.org/pypy/pypy/changeset/ddc6e9d447f3/

Log:    Also copy the storage of frozensets, so that methods like
        intersection, difference, etc. never mutate the original frozenset.

diff --git a/pypy/objspace/std/setobject.py b/pypy/objspace/std/setobject.py
--- a/pypy/objspace/std/setobject.py
+++ b/pypy/objspace/std/setobject.py
@@ -66,9 +66,9 @@
         """ Removes all elements from the set. """
         self.strategy.clear(self)
 
-    def copy(self):
-        """ Returns a clone of the set. """
-        return self.strategy.copy(self)
+    def copy_real(self):
+        """ Returns a clone of the set. Frozensets storages are also copied."""
+        return self.strategy.copy_real(self)
 
     def length(self):
         """ Returns the number of items inside the set. """
@@ -207,7 +207,7 @@
     def clear(self, w_set):
         raise NotImplementedError
 
-    def copy(self, w_set):
+    def copy_real(self, w_set):
         raise NotImplementedError
 
     def length(self, w_set):
@@ -302,7 +302,7 @@
     def clear(self, w_set):
         pass
 
-    def copy(self, w_set):
+    def copy_real(self, w_set):
         storage = self.erase(None)
         clone = w_set.from_storage_and_strategy(storage, self)
         return clone
@@ -340,22 +340,22 @@
         return False
 
     def difference(self, w_set, w_other):
-        return w_set.copy()
+        return w_set.copy_real()
 
     def difference_update(self, w_set, w_other):
         self.check_for_unhashable_objects(w_other)
 
     def intersect(self, w_set, w_other):
         self.check_for_unhashable_objects(w_other)
-        return w_set.copy()
+        return w_set.copy_real()
 
     def intersect_update(self, w_set, w_other):
         self.check_for_unhashable_objects(w_other)
-        return w_set.copy()
+        return w_set.copy_real()
 
     def intersect_multiple(self, w_set, others_w):
         self.intersect_multiple_update(w_set, others_w)
-        return w_set.copy()
+        return w_set.copy_real()
 
     def intersect_multiple_update(self, w_set, others_w):
         for w_other in others_w:
@@ -368,7 +368,7 @@
         return True
 
     def symmetric_difference(self, w_set, w_other):
-        return w_other.copy()
+        return w_other.copy_real()
 
     def symmetric_difference_update(self, w_set, w_other):
         w_set.strategy = w_other.strategy
@@ -412,10 +412,14 @@
     def clear(self, w_set):
         w_set.switch_to_empty_strategy()
 
-    def copy(self, w_set):
+    def copy_real(self, w_set):
         strategy = w_set.strategy
         if isinstance(w_set, W_FrozensetObject):
-            storage = w_set.sstorage
+            # only used internally since frozenset().copy()
+            # returns self in frozenset_copy__Frozenset
+            d = self.unerase(w_set.sstorage)
+            storage = self.erase(d.copy())
+            #storage = w_set.sstorage
         else:
             d = self.unerase(w_set.sstorage)
             storage = self.erase(d.copy())
@@ -621,7 +625,7 @@
 
     def intersect_multiple(self, w_set, others_w):
         #XXX find smarter implementations
-        result = w_set.copy()
+        result = w_set.copy_real()
         for w_other in others_w:
             if isinstance(w_other, W_BaseSetObject):
                 # optimization only
@@ -927,7 +931,7 @@
     w_left.add(w_other)
 
 def set_copy__Set(space, w_set):
-    return w_set.copy()
+    return w_set.copy_real()
 
 def frozenset_copy__Frozenset(space, w_left):
     if type(w_left) is W_FrozensetObject:
@@ -947,8 +951,8 @@
 
 def set_difference__Set(space, w_left, others_w):
     if len(others_w) == 0:
-        return w_left.copy()
-    result = w_left.copy()
+        return w_left.copy_real()
+    result = w_left.copy_real()
     set_difference_update__Set(space, result, others_w)
     return result
 
@@ -1176,7 +1180,7 @@
 
 def set_intersection__Set(space, w_left, others_w):
     if len(others_w) == 0:
-        return w_left.copy()
+        return w_left.copy_real()
     else:
         return _intersection_multiple(space, w_left, others_w)
 
@@ -1250,7 +1254,7 @@
 inplace_xor__Set_Frozenset = inplace_xor__Set_Set
 
 def or__Set_Set(space, w_left, w_other):
-    w_copy = w_left.copy()
+    w_copy = w_left.copy_real()
     w_copy.update(w_other)
     return w_copy
 
@@ -1259,7 +1263,7 @@
 or__Frozenset_Frozenset = or__Set_Set
 
 def set_union__Set(space, w_left, others_w):
-    result = w_left.copy()
+    result = w_left.copy_real()
     for w_other in others_w:
         if isinstance(w_other, W_BaseSetObject):
             result.update(w_other)     # optimization only
diff --git a/pypy/objspace/std/test/test_setobject.py b/pypy/objspace/std/test/test_setobject.py
--- a/pypy/objspace/std/test/test_setobject.py
+++ b/pypy/objspace/std/test/test_setobject.py
@@ -717,3 +717,53 @@
         x.pop()
         assert x == set([2,3])
         assert y == set([1,2,3])
+
+    def test_never_change_frozenset(self):
+        a = frozenset([1,2])
+        b = a.copy()
+        assert a is b
+
+        a = frozenset([1,2])
+        b = a.union(set([3,4]))
+        assert b == set([1,2,3,4])
+        assert a == set([1,2])
+
+        a = frozenset()
+        b = a.union(set([3,4]))
+        assert b == set([3,4])
+        assert a == set()
+
+        a = frozenset([1,2])#multiple
+        b = a.union(set([3,4]),[5,6])
+        assert b == set([1,2,3,4,5,6])
+        assert a == set([1,2])
+
+        a = frozenset([1,2,3])
+        b = a.difference(set([3,4,5]))
+        assert b == set([1,2])
+        assert a == set([1,2,3])
+
+        a = frozenset([1,2,3])#multiple
+        b = a.difference(set([3]), [2])
+        assert b == set([1])
+        assert a == set([1,2,3])
+
+        a = frozenset([1,2,3])
+        b = a.symmetric_difference(set([3,4,5]))
+        assert b == set([1,2,4,5])
+        assert a == set([1,2,3])
+
+        a = frozenset([1,2,3])
+        b = a.intersection(set([3,4,5]))
+        assert b == set([3])
+        assert a == set([1,2,3])
+
+        a = frozenset([1,2,3])#multiple
+        b = a.intersection(set([2,3,4]), [2])
+        assert b == set([2])
+        assert a == set([1,2,3])
+
+        raises(AttributeError, "frozenset().update()")
+        raises(AttributeError, "frozenset().difference_update()")
+        raises(AttributeError, "frozenset().symmetric_difference_update()")
+        raises(AttributeError, "frozenset().intersection_update()")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to