This makes it possible to flatten with respect to a given class (see the
docstring). It will be used in the logic module to denest classes of the same
type (but I believe it will also be useful in other places).
---
 sympy/utilities/iterables.py            |   25 +++++++++++++++++++++----
 sympy/utilities/tests/test_iterables.py |    6 ++++++
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/sympy/utilities/iterables.py b/sympy/utilities/iterables.py
index 24ebcd0..8d96b04 100644
--- a/sympy/utilities/iterables.py
+++ b/sympy/utilities/iterables.py
@@ -65,7 +65,7 @@ def make_list(expr, kind):
         return [expr]
 
 
-def flatten(iterable):
+def flatten(iterable, cls=None):
     """Recursively denest iterable containers.
 
        >>> flatten([1, 2, 3])
@@ -77,13 +77,30 @@ def flatten(iterable):
        >>> flatten( (1,2, (1, None)) )
        [1, 2, 1, None]
 
+       If the cls argument is specified, it will only flatten instances of that
+       class, for example:
+
+       >>> from sympy.core import Basic
+       >>> class MyOp(Basic):
+       ...     pass
+       ...
+       >>> flatten([MyOp(1, MyOp(2, 3))], cls=MyOp)
+       [1, 2, 3]
+
+
+
     adapted from http://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks
     """
-
+    if cls is None:
+        reducible = lambda x: hasattr(x, "__iter__") and not isinstance(x, 
basestring)
+    else:
+        reducible = lambda x: isinstance(x, cls)
     result = []
     for el in iterable:
-        if hasattr(el, "__iter__") and not isinstance(el, basestring):
-            result.extend(flatten(el))
+        if reducible(el):
+            if hasattr(el, 'args'):
+                el = el.args
+            result.extend(flatten(el, cls=cls))
         else:
             result.append(el)
     return result
diff --git a/sympy/utilities/tests/test_iterables.py b/sympy/utilities/tests/test_iterables.py
index 959c900..3aaf3be 100644
--- a/sympy/utilities/tests/test_iterables.py
+++ b/sympy/utilities/tests/test_iterables.py
@@ -25,6 +25,12 @@ def test_flatten():
     assert flatten( (1,(1,)) ) == [1,1]
     assert flatten( (x,(x,)) ) == [x,x]
 
+    from sympy.core.basic import Basic
+    class MyOp(Basic):
+        pass
+    assert flatten( [MyOp(x, y), z]) == [MyOp(x, y), z]
+    assert flatten( [MyOp(x, y), z], cls=MyOp) == [x, y, z]
+
 
 def test_subsets():
     assert list(subsets([1, 2, 3], 1)) == [[1], [2], [3]]
-- 
1.6.1.2


--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to 
sympy-patches+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to