This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1131c92233 [TVMScript] StmtDoc Printing (#12112)
1131c92233 is described below

commit 1131c922332ab05e985fb2f40d1c4fdb9fe51b05
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Thu Jul 28 01:33:44 2022 -0400

    [TVMScript] StmtDoc Printing (#12112)
    
    This PR adds:
    
    - StmtDoc Printing in PythonDocPrinter
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
---
 src/script/printer/base_doc_printer.cc             |  22 +
 src/script/printer/base_doc_printer.h              |  63 ++-
 src/script/printer/python_doc_printer.cc           | 212 ++++++-
 .../test_tvmscript_printer_python_doc_printer.py   | 615 ++++++++++++++++++++-
 4 files changed, 906 insertions(+), 6 deletions(-)

diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc
index 42d3f2d8f3..4129152129 100644
--- a/src/script/printer/base_doc_printer.cc
+++ b/src/script/printer/base_doc_printer.cc
@@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
     PrintTypedDoc(GetRef<DictDoc>(doc_node));
   } else if (const auto* doc_node = doc.as<SliceDocNode>()) {
     PrintTypedDoc(GetRef<SliceDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
+    PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<AssignDocNode>()) {
+    PrintTypedDoc(GetRef<AssignDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<IfDocNode>()) {
+    PrintTypedDoc(GetRef<IfDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<WhileDocNode>()) {
+    PrintTypedDoc(GetRef<WhileDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<ForDocNode>()) {
+    PrintTypedDoc(GetRef<ForDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
+    PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
+    PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<AssertDocNode>()) {
+    PrintTypedDoc(GetRef<AssertDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
+    PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
+    PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
+  } else if (const auto* doc_node = doc.as<ClassDocNode>()) {
+    PrintTypedDoc(GetRef<ClassDoc>(doc_node));
   } else {
     LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
     throw;
diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h
index d5bfdcd94c..8633dd0ded 100644
--- a/src/script/printer/base_doc_printer.h
+++ b/src/script/printer/base_doc_printer.h
@@ -84,22 +84,22 @@ class DocPrinter {
   virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;
 
   /*!
-   * \brief Virtual method to print a IdDoc
+   * \brief Virtual method to print an IdDoc
    */
   virtual void PrintTypedDoc(const IdDoc& doc) = 0;
 
   /*!
-   * \brief Virtual method to print a AttrAccessDoc
+   * \brief Virtual method to print an AttrAccessDoc
    */
   virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;
 
   /*!
-   * \brief Virtual method to print a IndexDoc
+   * \brief Virtual method to print an IndexDoc
    */
   virtual void PrintTypedDoc(const IndexDoc& doc) = 0;
 
   /*!
-   * \brief Virtual method to print a OperationDoc
+   * \brief Virtual method to print an OperationDoc
    */
   virtual void PrintTypedDoc(const OperationDoc& doc) = 0;
 
@@ -133,6 +133,61 @@ class DocPrinter {
    */
   virtual void PrintTypedDoc(const SliceDoc& doc) = 0;
 
+  /*!
+   * \brief Virtual method to print a StmtBlockDoc
+   */
+  virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print an AssignDoc
+   */
+  virtual void PrintTypedDoc(const AssignDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print an IfDoc
+   */
+  virtual void PrintTypedDoc(const IfDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a WhileDoc
+   */
+  virtual void PrintTypedDoc(const WhileDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a ForDoc
+   */
+  virtual void PrintTypedDoc(const ForDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a ScopeDoc
+   */
+  virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print an ExprStmtDoc
+   */
+  virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print an AssertDoc
+   */
+  virtual void PrintTypedDoc(const AssertDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a ReturnDoc
+   */
+  virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a FunctionDoc
+   */
+  virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;
+
+  /*!
+   * \brief Virtual method to print a ClassDoc
+   */
+  virtual void PrintTypedDoc(const ClassDoc& doc) = 0;
+
   /*!
    * \brief Increase the indent level of any content to be
    *        printed after this call
diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc
index 5c7b048f81..f44577ff80 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/python_doc_printer.cc
@@ -16,11 +16,15 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/script/printer/doc.h>
+
+#include <algorithm>
+#include <string>
 
 #include "../../support/str_escape.h"
+#include "../../support/utils.h"
 #include "./base_doc_printer.h"
 
 namespace tvm {
@@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
   void PrintTypedDoc(const DictDoc& doc) final;
   void PrintTypedDoc(const TupleDoc& doc) final;
   void PrintTypedDoc(const SliceDoc& doc) final;
+  void PrintTypedDoc(const StmtBlockDoc& doc) final;
+  void PrintTypedDoc(const AssignDoc& doc) final;
+  void PrintTypedDoc(const IfDoc& doc) final;
+  void PrintTypedDoc(const WhileDoc& doc) final;
+  void PrintTypedDoc(const ForDoc& doc) final;
+  void PrintTypedDoc(const ExprStmtDoc& doc) final;
+  void PrintTypedDoc(const AssertDoc& doc) final;
+  void PrintTypedDoc(const ReturnDoc& doc) final;
+  void PrintTypedDoc(const ScopeDoc& doc) final;
+  void PrintTypedDoc(const FunctionDoc& doc) final;
+  void PrintTypedDoc(const ClassDoc& doc) final;
 
  private:
+  /*! \brief Write a bare line break; unlike NewLine(), no indentation follows it. */
+  void NewLineWithoutIndent() { output_ << '\n'; }
+
   template <typename DocType>
   void PrintJoinedDocs(const Array<DocType>& docs, const std::string& 
separator) {
     bool is_first = true;
@@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
       PrintDoc(doc);
     }
   }
+
+  /*!
+   * \brief Print `docs` as the indented body of a compound statement.
+   *
+   * Every statement is preceded by a line break at the deeper indent level.
+   * An empty body is rendered as a single `pass` so the header that precedes
+   * it is still followed by a valid Python suite.
+   */
+  void PrintIndentedBlock(const Array<StmtDoc>& docs) {
+    IncreaseIndent();
+    if (docs.empty()) {
+      NewLine();
+      output_ << "pass";
+    } else {
+      for (const StmtDoc& stmt : docs) {
+        NewLine();
+        PrintDoc(stmt);
+      }
+    }
+    DecreaseIndent();
+  }
+
+  /*!
+   * \brief Print each decorator as `@<expr>`, in list order, one per line.
+   *
+   * A line break follows every decorator, so the `def`/`class` header printed
+   * next starts on its own line.
+   */
+  void PrintDecorators(const Array<ExprDoc>& decorators) {
+    for (const ExprDoc& deco : decorators) {
+      output_ << '@';
+      PrintDoc(deco);
+      NewLine();
+    }
+  }
+
+  /*!
+   * \brief If `stmt` has a comment, append it to the current line as `  # <comment>`.
+   *
+   * An inline comment cannot span lines, so a comment containing '\n' fails a
+   * CHECK whose message is prefixed with "ValueError:".
+   */
+  void MaybePrintCommentInline(const StmtDoc& stmt) {
+    if (!stmt->comment.defined()) {
+      return;
+    }
+    const std::string& comment = stmt->comment.value();
+    CHECK(comment.find('\n') == std::string::npos)
+        << "ValueError: the comment string of " << stmt->GetTypeKey() << " cannot have newline.";
+    output_ << "  # " << comment;
+  }
+
+  /*!
+   * \brief If `stmt` has a comment, print every line of it as `# <line>`,
+   *        each followed by a line break.
+   *
+   * Compound statements (if/while/for/with) call this before printing their
+   * header, so the comment ends up directly above the header line.
+   */
+  void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
+    if (!stmt->comment.defined()) {
+      return;
+    }
+    for (const std::string& comment_line : support::Split(stmt->comment.value(), '\n')) {
+      output_ << "# " << comment_line;
+      NewLine();
+    }
+  }
+
+  /*!
+   * \brief Print `comment` as a triple-quoted docstring one indent level deeper.
+   *
+   * Non-empty lines are written at the deeper indentation. Empty lines become
+   * bare line breaks so the docstring carries no whitespace-only lines.
+   */
+  void PrintBlockComment(const String& comment) {
+    IncreaseIndent();
+    NewLine() << "\"\"\"";
+    for (const std::string& text : support::Split(comment, '\n')) {
+      if (!text.empty()) {
+        NewLine() << text;
+      } else {
+        NewLineWithoutIndent();
+      }
+    }
+    NewLine() << "\"\"\"";
+    DecreaseIndent();
+  }
 };
 
 void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
@@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
   }
 }
 
+/*!
+ * \brief Print the statements of a StmtBlockDoc, one per line.
+ *
+ * Line breaks go only *between* statements. NewLine() also writes the current
+ * indentation (that is why NewLineWithoutIndent exists), so a break after the
+ * last statement would leave a dangling whitespace-only line whenever the
+ * block is printed inside an indented body. An empty block prints nothing.
+ */
+void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
+  bool is_first = true;
+  for (const StmtDoc& stmt : doc->stmts) {
+    if (is_first) {
+      is_first = false;
+    } else {
+      NewLine();
+    }
+    PrintDoc(stmt);
+  }
+}
+
+/*!
+ * \brief Print an assignment or a bare annotated declaration.
+ *
+ * Forms: `lhs = rhs`, `lhs: annotation = rhs` and `lhs: annotation`, followed
+ * by any inline comment. A TupleDoc target with two or more elements is
+ * printed without its outer parentheses (`x, y = z`). One-element and empty
+ * tuple targets go through the TupleDoc printer instead: joining their
+ * elements would produce `x = z` (silently dropping the unpacking) or ` = z`
+ * (not valid Python).
+ */
+void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
+  const auto* tuple_doc = doc->lhs.as<TupleDocNode>();
+  if (tuple_doc != nullptr && tuple_doc->elements.size() > 1) {
+    PrintJoinedDocs(tuple_doc->elements, ", ");
+  } else {
+    // Also covers 0/1-element tuples; relies on the TupleDoc printer for
+    // their parenthesized form.
+    PrintDoc(doc->lhs);
+  }
+
+  if (doc->annotation.defined()) {
+    output_ << ": ";
+    PrintDoc(doc->annotation.value());
+  }
+  if (doc->rhs.defined()) {
+    output_ << " = ";
+    PrintDoc(doc->rhs.value());
+  }
+  MaybePrintCommentInline(doc);
+}
+
+/*!
+ * \brief Print `if <predicate>:` with its then-branch and, when non-empty, an
+ *        `else:` branch. Any comment is printed above the `if` line.
+ */
+void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+
+  output_ << "if ";
+  PrintDoc(doc->predicate);
+  output_ << ":";
+  PrintIndentedBlock(doc->then_branch);
+
+  if (doc->else_branch.empty()) {
+    return;  // no `else:` for an empty else-branch
+  }
+  NewLine();
+  output_ << "else:";
+  PrintIndentedBlock(doc->else_branch);
+}
+
+/*! \brief Print `while <predicate>:` and its body; any comment goes above the header. */
+void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+  output_ << "while ";
+  PrintDoc(doc->predicate);
+  output_ << ':';
+  PrintIndentedBlock(doc->body);
+}
+
+/*! \brief Print `for <lhs> in <rhs>:` and its body; any comment goes above the header. */
+void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+  output_ << "for ";
+  PrintDoc(doc->lhs);
+  output_ << " in ";
+  PrintDoc(doc->rhs);
+  output_ << ':';
+  PrintIndentedBlock(doc->body);
+}
+
+/*!
+ * \brief Print `with <rhs>[ as <lhs>]:` and its body; any comment goes above
+ *        the `with` line. The `as` clause is omitted when `lhs` is undefined.
+ */
+void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+  output_ << "with ";
+  PrintDoc(doc->rhs);
+  if (doc->lhs.defined()) {
+    output_ << " as ";
+    PrintDoc(doc->lhs.value());
+  }
+  output_ << ':';
+  PrintIndentedBlock(doc->body);
+}
+
+/*! \brief Print an expression used as a statement, then any inline comment. */
+void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
+  PrintDoc(doc->expr);
+  MaybePrintCommentInline(doc);
+}
+
+/*! \brief Print `assert <test>` plus `, <msg>` when a message is set, then any inline comment. */
+void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
+  output_ << "assert ";
+  PrintDoc(doc->test);
+  if (!doc->msg.defined()) {
+    MaybePrintCommentInline(doc);
+    return;
+  }
+  output_ << ", ";
+  PrintDoc(doc->msg.value());
+  MaybePrintCommentInline(doc);
+}
+
+/*! \brief Print `return <value>`, then any inline comment. */
+void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
+  output_ << "return" << ' ';
+  PrintDoc(doc->value);
+  MaybePrintCommentInline(doc);
+}
+
+/*!
+ * \brief Print a function definition.
+ *
+ * The output consists of:
+ *  - the decorators;
+ *  - `def <name>(<args>) -> <return_type>:`;
+ *  - an optional docstring built from the comment;
+ *  - the indented body.
+ *
+ * Each argument is an AssignDoc, rendered as `name[: annotation][ = default]`.
+ * Arguments must not carry comments (ICHECK). An un-indented line break
+ * follows the body, which leaves a blank line after the definition.
+ */
+void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
+  for (const AssignDoc& arg : doc->args) {
+    ICHECK(!arg->comment.defined()) << "Function arg cannot have comment attached to them.";
+  }
+
+  PrintDecorators(doc->decorators);
+
+  output_ << "def ";
+  PrintDoc(doc->name);
+  output_ << "(";
+  PrintJoinedDocs(doc->args, ", ");
+  output_ << ") -> ";
+  PrintDoc(doc->return_type);
+  output_ << ":";
+
+  if (doc->comment.defined()) {
+    PrintBlockComment(doc->comment.value());
+  }
+  PrintIndentedBlock(doc->body);
+  NewLineWithoutIndent();
+}
+
+/*!
+ * \brief Print a class definition.
+ *
+ * The output consists of the decorators, `class <name>:`, an optional
+ * docstring built from the comment, and the indented body. An un-indented
+ * line break ends the definition.
+ */
+void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
+  PrintDecorators(doc->decorators);
+  output_ << "class ";
+  PrintDoc(doc->name);
+  output_ << ':';
+  if (doc->comment.defined()) {
+    PrintBlockComment(doc->comment.value());
+  }
+  PrintIndentedBlock(doc->body);
+  NewLineWithoutIndent();
+}
+
 String DocToPythonScript(Doc doc, int indent_spaces) {
   PythonDocPrinter printer(indent_spaces);
   printer.Append(doc);
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index b65eaa6b98..523f62d8b5 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -15,18 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import itertools
 
 from tvm.script.printer.doc import (
+    AssertDoc,
+    AssignDoc,
     CallDoc,
+    ClassDoc,
     DictDoc,
+    ExprStmtDoc,
+    ForDoc,
+    FunctionDoc,
     IdDoc,
+    IfDoc,
     LambdaDoc,
     ListDoc,
     LiteralDoc,
     OperationDoc,
     OperationKind,
+    ReturnDoc,
+    ScopeDoc,
     SliceDoc,
+    StmtBlockDoc,
     TupleDoc,
+    WhileDoc,
 )
 from tvm.script.printer.doc_printer import to_python_script
 
@@ -36,10 +48,19 @@ def format_script(s: str) -> str:
     Remove leading and trailing blank lines, and make the minimum idention 0
     """
     s = s.strip("\n")
+
     non_empty_lines = [line for line in s.splitlines() if line and not 
line.isspace()]
+    if not non_empty_lines:
+        # no actual content
+        return "\n"
+
     line_indents = [len(line) - len(line.lstrip(" ")) for line in 
non_empty_lines]
     spaces_to_remove = min(line_indents)
-    return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n"
+
+    cleaned_lines = "\n".join(line[spaces_to_remove:] for line in 
s.splitlines())
+    if not cleaned_lines.endswith("\n"):
+        cleaned_lines += "\n"
+    return cleaned_lines
 
 
 @pytest.mark.parametrize(
@@ -59,6 +80,7 @@ def format_script(s: str) -> str:
         (LiteralDoc(3.25), "3.25"),
         (LiteralDoc(-0.5), "-0.5"),
     ],
+    ids=itertools.count(),
 )
 def test_print_literal_doc(doc, expected):
     assert to_python_script(doc) == format_script(expected)
@@ -73,6 +95,7 @@ def test_print_literal_doc(doc, expected):
         "test_case",
         "test123",
     ],
+    ids=itertools.count(),
 )
 def test_print_id_doc(name):
     doc = IdDoc(name)
@@ -87,6 +110,7 @@ def test_print_id_doc(name):
         "Attr",
         "attr_1",
     ],
+    ids=itertools.count(),
 )
 def test_print_attr_doc(attr):
     doc = IdDoc("x").attr(attr)
@@ -125,6 +149,7 @@ def test_print_attr_doc(attr):
             "[x, y, z]",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_index_doc(indices, expected):
     doc = IdDoc("x")[indices]
@@ -271,6 +296,7 @@ def test_operation_doc_test_exhaustive():
             "(x, y, key0=u, key1=v)",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_call_doc(args, kwargs, expected):
     doc = CallDoc(IdDoc("f"), *args, **kwargs)
@@ -297,6 +323,7 @@ def test_print_call_doc(args, kwargs, expected):
             "lambda x, y, z: 0",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_lambda_doc(args, expected):
     doc = LambdaDoc(args, body=LiteralDoc(0))
@@ -323,6 +350,7 @@ def test_print_lambda_doc(args, expected):
             "[x, y, z]",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_list_doc(elements, expected):
     doc = ListDoc(elements)
@@ -349,6 +377,7 @@ def test_print_list_doc(elements, expected):
             "(x, y, z)",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_tuple_doc(elements, expected):
     doc = TupleDoc(elements)
@@ -379,6 +408,7 @@ def test_print_tuple_doc(elements, expected):
             '{"key_x": x, "key_y": y, "key_z": z}',
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_dict_doc(content, expected):
     doc = DictDoc(content)
@@ -421,7 +451,590 @@ def test_print_dict_doc(content, expected):
             "1:2:3",
         ),
     ],
+    ids=itertools.count(),
 )
 def test_print_slice_doc(slice_doc, expected):
     doc = IdDoc("x")[slice_doc]
     assert to_python_script(doc) == format_script(f"x[{expected}]")
+
+
+# StmtBlockDoc with zero, one and two statements. Both sides are stripped
+# before comparing, so leading/trailing newlines are not checked by this test.
+@pytest.mark.parametrize(
+    "stmts, expected",
+    [
+        (
+            [],
+            "",
+        ),
+        (
+            [ExprStmtDoc(IdDoc("x"))],
+            "x",
+        ),
+        (
+            [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+            """
+            x
+            y
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_stmt_block_doc(stmts, expected):
+    doc = StmtBlockDoc(stmts)
+    assert to_python_script(doc).strip() == format_script(expected).strip()
+
+# AssignDoc(lhs, rhs, annotation) covers:
+#   - plain and annotated assignments;
+#   - a bare annotated declaration (rhs=None);
+#   - tuple targets, which print without their outer parentheses while nested
+#     tuples keep theirs.
+@pytest.mark.parametrize(
+    "doc, expected",
+    [
+        (
+            AssignDoc(IdDoc("x"), IdDoc("y"), None),
+            "x = y",
+        ),
+        (
+            AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+            "x: int = y",
+        ),
+        (
+            AssignDoc(IdDoc("x"), None, IdDoc("int")),
+            "x: int",
+        ),
+        (
+            AssignDoc(TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("z"), None),
+            "x, y = z",
+        ),
+        (
+            AssignDoc(TupleDoc([IdDoc("x"), TupleDoc([IdDoc("y"), 
IdDoc("z")])]), IdDoc("z"), None),
+            "x, (y, z) = z",
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_assign_doc(doc, expected):
+    assert to_python_script(doc) == format_script(expected)
+
+
+# IfDoc(predicate, then_branch, else_branch): an empty then-branch prints
+# `pass`, and an empty else-branch omits the `else:` clause entirely.
+@pytest.mark.parametrize(
+    "then_branch, else_branch, expected",
+    [
+        (
+            [ExprStmtDoc(IdDoc("x"))],
+            [],
+            """
+            if pred:
+                x
+            """,
+        ),
+        (
+            [],
+            [ExprStmtDoc(IdDoc("y"))],
+            """
+            if pred:
+                pass
+            else:
+                y
+            """,
+        ),
+        (
+            [ExprStmtDoc(IdDoc("x"))],
+            [ExprStmtDoc(IdDoc("y"))],
+            """
+            if pred:
+                x
+            else:
+                y
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_if_doc(then_branch, else_branch, expected):
+    doc = IfDoc(IdDoc("pred"), then_branch, else_branch)
+    assert to_python_script(doc) == format_script(expected)
+
+
+# WhileDoc(predicate, body): a non-empty body, and an empty body printed as `pass`.
+@pytest.mark.parametrize(
+    "body, expected",
+    [
+        (
+            [ExprStmtDoc(IdDoc("x"))],
+            """
+            while pred:
+                x
+            """,
+        ),
+        (
+            [],
+            """
+            while pred:
+                pass
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_while_doc(body, expected):
+    doc = WhileDoc(IdDoc("pred"), body)
+    assert to_python_script(doc) == format_script(expected)
+
+
+# ForDoc(lhs, rhs, body): a non-empty body, and an empty body printed as `pass`.
+@pytest.mark.parametrize(
+    "body, expected",
+    [
+        (
+            [ExprStmtDoc(IdDoc("x"))],
+            """
+            for x in y:
+                x
+            """,
+        ),
+        (
+            [],
+            """
+            for x in y:
+                pass
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_for_doc(body, expected):
+    doc = ForDoc(IdDoc("x"), IdDoc("y"), body)
+    assert to_python_script(doc) == format_script(expected)
+
+
+# ScopeDoc(lhs, rhs, body): `with rhs as lhs:` when lhs is given and
+# `with rhs:` when lhs is None, each with an empty and a non-empty body.
+@pytest.mark.parametrize(
+    "lhs, body, expected",
+    [
+        (
+            IdDoc("c"),
+            [ExprStmtDoc(IdDoc("x"))],
+            """
+            with context() as c:
+                x
+            """,
+        ),
+        (
+            IdDoc("c"),
+            [],
+            """
+            with context() as c:
+                pass
+            """,
+        ),
+        (
+            None,
+            [],
+            """
+            with context():
+                pass
+            """,
+        ),
+        (
+            None,
+            [ExprStmtDoc(IdDoc("x"))],
+            """
+            with context():
+                x
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_scope_doc(lhs, body, expected):
+    doc = ScopeDoc(lhs, CallDoc(IdDoc("context")), body)
+    assert to_python_script(doc) == format_script(expected)
+
+
+def test_print_expr_stmt_doc():
+    """An ExprStmtDoc wrapping the call `f(x)` prints as `f(x)`."""
+    call = CallDoc(IdDoc("f"), IdDoc("x"))
+    expected = format_script("f(x)")
+    assert to_python_script(ExprStmtDoc(call)) == expected
+
+
+# AssertDoc(test, msg): without a message, and with a string-literal message
+# appended after a comma.
+@pytest.mark.parametrize(
+    "msg, expected",
+    [
+        (
+            None,
+            """
+            assert True
+            """,
+        ),
+        (
+            LiteralDoc("test message"),
+            """
+            assert True, "test message"
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_assert_doc(msg, expected):
+    test = LiteralDoc(True)
+
+    doc = AssertDoc(test, msg)
+
+    assert to_python_script(doc) == format_script(expected)
+
+
+# ReturnDoc(value): a literal None and an identifier.
+@pytest.mark.parametrize(
+    "value, expected",
+    [
+        (
+            LiteralDoc(None),
+            """
+            return None
+            """,
+        ),
+        (
+            IdDoc("x"),
+            """
+            return x
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_return_doc(value, expected):
+    doc = ReturnDoc(value)
+    assert to_python_script(doc) == format_script(expected)
+
+
+# FunctionDoc(name, args, decorators, return_type, body) covers:
+#   - arguments with an annotation and optionally a default;
+#   - zero, one or two decorators, printed in list order;
+#   - an empty body (printed as `pass`) and a two-statement body.
+@pytest.mark.parametrize(
+    "args, decorators, body, expected",
+    [
+        (
+            [],
+            [],
+            [],
+            """
+            def func() -> None:
+                pass
+            """,
+        ),
+        (
+            [AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))],
+            [],
+            [],
+            """
+            def func(x: int) -> None:
+                pass
+            """,
+        ),
+        (
+            [AssignDoc(IdDoc("x"), rhs=LiteralDoc(1), 
annotation=IdDoc("int"))],
+            [],
+            [],
+            """
+            def func(x: int = 1) -> None:
+                pass
+            """,
+        ),
+        (
+            [],
+            [IdDoc("wrap")],
+            [],
+            """
+            @wrap
+            def func() -> None:
+                pass
+            """,
+        ),
+        (
+            [],
+            [IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+            [],
+            """
+            @wrap_outter
+            @wrap_inner
+            def func() -> None:
+                pass
+            """,
+        ),
+        (
+            [
+                AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")),
+                AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), 
annotation=IdDoc("int")),
+            ],
+            [IdDoc("wrap")],
+            [],
+            """
+            @wrap
+            def func(x: int, y: int = 1) -> None:
+                pass
+            """,
+        ),
+        (
+            [
+                AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")),
+                AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), 
annotation=IdDoc("int")),
+            ],
+            [IdDoc("wrap")],
+            [
+                AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, 
[IdDoc("x"), LiteralDoc(1)])),
+                AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, 
[IdDoc("y"), LiteralDoc(1)])),
+            ],
+            """
+            @wrap
+            def func(x: int, y: int = 1) -> None:
+                y = x + 1
+                y = y - 1
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_function_doc(args, decorators, body, expected):
+    doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body)
+    assert to_python_script(doc) == format_script(expected)
+
+
+def get_func_doc_for_class(name):
+    """Build a decorated two-argument FunctionDoc called ``name``.
+
+    The returned doc prints as::
+
+        @wrap
+        def <name>(x: int, y: int = 1) -> None:
+            y = x + 1
+            y = y - 1
+    """
+    args = [
+        AssignDoc(IdDoc(arg_name), rhs=default, annotation=IdDoc("int"))
+        for arg_name, default in (("x", None), ("y", LiteralDoc(1)))
+    ]
+    body = [
+        AssignDoc(IdDoc("y"), OperationDoc(kind, [IdDoc(operand), LiteralDoc(1)]))
+        for kind, operand in ((OperationKind.Add, "x"), (OperationKind.Sub, "y"))
+    ]
+    return FunctionDoc(
+        name=IdDoc(name),
+        args=args,
+        decorators=[IdDoc("wrap")],
+        return_type=LiteralDoc(None),
+        body=body,
+    )
+
+
+# ClassDoc(name, decorators, body): an empty body prints `pass`. Method bodies
+# are FunctionDocs, and each expected string keeps the blank line that follows
+# every printed method.
+@pytest.mark.parametrize(
+    "decorators, body, expected",
+    [
+        (
+            [],
+            [],
+            """
+            class TestClass:
+                pass
+            """,
+        ),
+        (
+            [IdDoc("wrap")],
+            [],
+            """
+            @wrap
+            class TestClass:
+                pass
+            """,
+        ),
+        (
+            [IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+            [],
+            """
+            @wrap_outter
+            @wrap_inner
+            class TestClass:
+                pass
+            """,
+        ),
+        (
+            [IdDoc("wrap")],
+            [get_func_doc_for_class("f1")],
+            """
+            @wrap
+            class TestClass:
+                @wrap
+                def f1(x: int, y: int = 1) -> None:
+                    y = x + 1
+                    y = y - 1
+
+            """,
+        ),
+        (
+            [IdDoc("wrap")],
+            [get_func_doc_for_class("f1"), get_func_doc_for_class("f2")],
+            """
+            @wrap
+            class TestClass:
+                @wrap
+                def f1(x: int, y: int = 1) -> None:
+                    y = x + 1
+                    y = y - 1
+
+                @wrap
+                def f2(x: int, y: int = 1) -> None:
+                    y = x + 1
+                    y = y - 1
+
+            """,
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_class_doc(decorators, body, expected):
+    doc = ClassDoc(IdDoc("TestClass"), decorators, body)
+    assert to_python_script(doc) == format_script(expected)
+
+
+# Comment placement by statement kind:
+#   - simple statements (assign, expression, assert, return) get a trailing
+#     `  # comment`;
+#   - compound statements (if, while, for, with) get `# line` comments above
+#     the header, one per comment line;
+#   - functions and classes get a docstring, in which empty comment lines
+#     stay empty.
+@pytest.mark.parametrize(
+    "doc, comment, expected",
+    [
+        (
+            AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+            "comment",
+            """
+            x: int = y  # comment
+            """,
+        ),
+        (
+            IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], 
[ExprStmtDoc(IdDoc("z"))]),
+            "comment",
+            """
+            # comment
+            if x:
+                y
+            else:
+                z
+            """,
+        ),
+        (
+            IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], 
[ExprStmtDoc(IdDoc("z"))]),
+            "comment line 1\ncomment line 2",
+            """
+            # comment line 1
+            # comment line 2
+            if x:
+                y
+            else:
+                z
+            """,
+        ),
+        (
+            WhileDoc(
+                LiteralDoc(True),
+                [
+                    AssignDoc(IdDoc("x"), IdDoc("y")),
+                ],
+            ),
+            "comment",
+            """
+            # comment
+            while True:
+                x = y
+            """,
+        ),
+        (
+            ForDoc(IdDoc("x"), IdDoc("y"), []),
+            "comment",
+            """
+            # comment
+            for x in y:
+                pass
+            """,
+        ),
+        (
+            ScopeDoc(IdDoc("x"), IdDoc("y"), []),
+            "comment",
+            """
+            # comment
+            with y as x:
+                pass
+            """,
+        ),
+        (
+            ExprStmtDoc(IdDoc("x")),
+            "comment",
+            """
+            x  # comment
+            """,
+        ),
+        (
+            AssertDoc(LiteralDoc(True)),
+            "comment",
+            """
+            assert True  # comment
+            """,
+        ),
+        (
+            ReturnDoc(LiteralDoc(1)),
+            "comment",
+            """
+            return 1  # comment
+            """,
+        ),
+        (
+            get_func_doc_for_class("f"),
+            "comment",
+            '''
+            @wrap
+            def f(x: int, y: int = 1) -> None:
+                """
+                comment
+                """
+                y = x + 1
+                y = y - 1
+            ''',
+        ),
+        (
+            get_func_doc_for_class("f"),
+            "comment line 1\n\ncomment line 3",
+            '''
+            @wrap
+            def f(x: int, y: int = 1) -> None:
+                """
+                comment line 1
+
+                comment line 3
+                """
+                y = x + 1
+                y = y - 1
+            ''',
+        ),
+        (
+            ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]),
+            "comment",
+            '''
+            @wrap
+            class TestClass:
+                """
+                comment
+                """
+                pass
+            ''',
+        ),
+        (
+            ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]),
+            "comment line 1\n\ncomment line 3",
+            '''
+            @wrap
+            class TestClass:
+                """
+                comment line 1
+
+                comment line 3
+                """
+                pass
+            ''',
+        ),
+    ],
+    ids=itertools.count(),
+)
+def test_print_doc_comment(doc, comment, expected):
+    doc.comment = comment
+    assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+    "doc",
+    [
+        AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+        ExprStmtDoc(IdDoc("x")),
+        AssertDoc(IdDoc("x")),
+        ReturnDoc(IdDoc("x")),
+    ],
+    ids=itertools.count(),
+)
+def test_print_invalid_multiline_doc_comment(doc):
+    """Statements whose comment is printed inline reject multi-line comments.
+
+    Printing must raise ValueError instead of emitting a broken `# ...` comment.
+    """
+    doc.comment = "1\n2"
+    with pytest.raises(ValueError, match="cannot have newline"):
+        to_python_script(doc)

Reply via email to