Mainly:
  * make the doctest runner actually work
  * better reporting of doctest failures
  * fix a couple of bugs related to importing
---
 sympy/utilities/runtests.py |  121 ++++++++++++++++++++-----------------------
 1 files changed, 57 insertions(+), 64 deletions(-)

diff --git a/sympy/utilities/runtests.py b/sympy/utilities/runtests.py
index 74248f8..80b7e92 100644
--- a/sympy/utilities/runtests.py
+++ b/sympy/utilities/runtests.py
@@ -295,72 +295,45 @@ class SymPyDocTests(object):
         return self._reporter.finish()
 
     def test_file(self, filename):
+        def setup_pprint():
+            from sympy import pprint_use_unicode
+            # force pprint to be in ascii mode in doctests
+            pprint_use_unicode(False)
+
+            # hook our nice, hash-stable strprinter
+            from sympy.interactive import init_printing
+            from sympy.printing import sstrrepr
+            init_printing(sstrrepr)
+
         import doctest
-        print filename
-        doctest.testfile(filename)
-        return
-        import sympy
-        import imp
-        name = "test%d" % self._count
-        name = os.path.splitext(os.path.basename(filename))[0]
-        self._count += 1
+        import unittest
+        from StringIO import StringIO
+
+        rel_name = filename[len(self._root_dir)+1:]
+        module = rel_name.replace('/', '.')[:-3]
+        setup_pprint()
         try:
-            #module = __import__(filename, globals(), locals())
-            module = imp.load_source(name, filename)
-        except ImportError:
+            module = doctest._normalize_module(module)
+            tests = doctest.DocTestFinder().find(module)
+        except:
             self._reporter.import_error(filename, sys.exc_info())
             return
-        disabled = getattr(module, "disabled", False)
-        if disabled:
-            funcs = []
-        else:
-            funcs = sorted(module.__dict__.keys())
-            # we need to filter only those functions that begin with 'test_'
-            funcs = [f for f in funcs if f.startswith("test_")]
-            # and also that are defined in this module (i.e. not imported from
-            # other modules using the import statemet, like "from sympy import
-            # *"). This is tricky to achieve. The easiest is to compare m1 and
-            # m2 below, if they are equal, we are sure that the "f" is from
-            # module. However, when one uses the XFAIL decorator, the m1
-            # appears from "sympy.utilities.pytest", so we check this one
-            # explicitely. This is not robust, as it will stop working when we
-            # move XFAIL to another module, or if we use some decorator defined
-            # elsewhere. Any help with this is welcomed.
-            funcs2 = []
-            m2 = module.__name__
-            for f in funcs:
-                f = module.__dict__[f]
-                m1 = f.__module__
-                if m1 == m2 or m1 == "sympy.utilities.pytest":
-                    if isgeneratorfunction(f):
-                        for fg in f():
-                            func = fg[0]
-                            args = fg[1:]
-                            fgw = lambda: func(*args)
-                            funcs2.append((fgw, inspect.getsourcelines(f)[1]))
-                    else:
-                        funcs2.append((f, inspect.getsourcelines(f)[1]))
-            funcs2.sort(key=lambda x: x[1])
-            funcs = [x[0] for x in funcs2]
-        self._reporter.entering_filename(filename, len(funcs))
-        for f in funcs:
-            self._reporter.entering_test(f)
+
+        tests.sort()
+        tests = [test for test in tests if len(test.examples) > 0]
+        self._reporter.entering_filename(filename, len(tests))
+        for test in tests:
+            assert len(test.examples) != 0
+            runner = doctest.DocTestRunner()
+            old = sys.stdout
+            new = StringIO()
+            sys.stdout = new
             try:
-                f()
-            except KeyboardInterrupt:
-                raise
-            except:
-                t, v, tr = sys.exc_info()
-                if t is AssertionError:
-                    self._reporter.test_fail((t, v, tr))
-                elif t.__name__ == "Skipped":
-                    self._reporter.test_skip()
-                elif t.__name__ == "XFail":
-                    self._reporter.test_xfail()
-                elif t.__name__ == "XPass":
-                    self._reporter.test_xpass()
-                else:
-                    self._reporter.test_exception((t, v, tr))
+                f, t = runner.run(test, out=new.write, clear_globs=False)
+            finally:
+                sys.stdout = old
+            if f > 0:
+                self._reporter.doctest_fail(test.name, new.getvalue())
             else:
                 self._reporter.test_pass()
         self._reporter.leaving_filename()
@@ -390,7 +363,7 @@ class SymPyDocTests(object):
         wildcards = [dir]
         for i in range(level):
             wildcards.append(os.path.join(wildcards[-1], "*"))
-        p = [os.path.join(x, "test_*.py") for x in wildcards]
+        p = [os.path.join(x, "*.py") for x in wildcards]
         return p
 
     def get_tests(self, dir):
@@ -422,6 +395,7 @@ class PyTestReporter(Reporter):
         self._xfailed = 0
         self._xpassed = []
         self._failed = []
+        self._failed_doctest = []
         self._passed = 0
         self._skipped = 0
         self._exceptions = []
@@ -531,6 +505,8 @@ class PyTestReporter(Reporter):
         text = "tests finished: %d passed" % self._passed
         if len(self._failed) > 0:
             text += ", %d failed" % len(self._failed)
+        if len(self._failed_doctest) > 0:
+            text += ", %d failed" % len(self._failed_doctest)
         if self._skipped > 0:
             text += ", %d skipped" % self._skipped
         if self._xfailed > 0:
@@ -570,8 +546,18 @@ class PyTestReporter(Reporter):
                 self.write_exception(t, val, tb)
             self.write("\n")
 
+        if self._tb_style != "no" and len(self._failed_doctest) > 0:
+            #self.write_center("Failed", "_")
+            for e in self._failed_doctest:
+                filename, msg = e
+                self.write_center("", "_")
+                self.write_center("%s" % filename, "_")
+                self.write(msg)
+            self.write("\n")
+
         self.write_center(text)
-        ok = len(self._failed) == 0 and len(self._exceptions) == 0
+        ok = len(self._failed) == 0 and len(self._exceptions) == 0 and \
+                len(self._failed_doctest) == 0
         if not ok:
             self.write("DO *NOT* COMMIT!\n")
         return ok
@@ -612,6 +598,13 @@ class PyTestReporter(Reporter):
         self.write("F")
         self._active_file_error = True
 
+    def doctest_fail(self, name, error_msg):
+        # the first line contains "******", remove it:
+        error_msg = "\n".join(error_msg.split("\n")[1:])
+        self._failed_doctest.append((name, error_msg))
+        self.write("F")
+        self._active_file_error = True
+
     def test_pass(self):
         self._passed += 1
         if self._verbose:
@@ -632,7 +625,7 @@ class PyTestReporter(Reporter):
         self._exceptions.append((filename, None, exc_info))
         rel_name = filename[len(self._root_dir)+1:]
         self.write(rel_name)
-        self.write("[?] Failed to import")
+        self.write("[?]   Failed to import")
         if self._colors:
             self.write(" ")
             self.write("[FAIL]", "Red", align="right")
-- 
1.6.0.4


--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"sympy-patches" group.
To post to this group, send email to sympy-patches@googlegroups.com
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at 
http://groups.google.com/group/sympy-patches?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to