This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 1131c92233 [TVMScript] StmtDoc Printing (#12112) 1131c92233 is described below commit 1131c922332ab05e985fb2f40d1c4fdb9fe51b05 Author: Lite Ye <yelite...@gmail.com> AuthorDate: Thu Jul 28 01:33:44 2022 -0400 [TVMScript] StmtDoc Printing (#12112) This PR adds: - StmtDoc Printing in PythonDocPrinter Tracking issue: https://github.com/apache/tvm/issues/11912 --- src/script/printer/base_doc_printer.cc | 22 + src/script/printer/base_doc_printer.h | 63 ++- src/script/printer/python_doc_printer.cc | 212 ++++++- .../test_tvmscript_printer_python_doc_printer.py | 615 ++++++++++++++++++++- 4 files changed, 906 insertions(+), 6 deletions(-) diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc index 42d3f2d8f3..4129152129 100644 --- a/src/script/printer/base_doc_printer.cc +++ b/src/script/printer/base_doc_printer.cc @@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) { PrintTypedDoc(GetRef<DictDoc>(doc_node)); } else if (const auto* doc_node = doc.as<SliceDocNode>()) { PrintTypedDoc(GetRef<SliceDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) { + PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<AssignDocNode>()) { + PrintTypedDoc(GetRef<AssignDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<IfDocNode>()) { + PrintTypedDoc(GetRef<IfDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<WhileDocNode>()) { + PrintTypedDoc(GetRef<WhileDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<ForDocNode>()) { + PrintTypedDoc(GetRef<ForDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<ScopeDocNode>()) { + PrintTypedDoc(GetRef<ScopeDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) { + PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<AssertDocNode>()) { 
PrintTypedDoc(GetRef<AssertDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<ReturnDocNode>()) { + PrintTypedDoc(GetRef<ReturnDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<FunctionDocNode>()) { + PrintTypedDoc(GetRef<FunctionDoc>(doc_node)); + } else if (const auto* doc_node = doc.as<ClassDocNode>()) { + PrintTypedDoc(GetRef<ClassDoc>(doc_node)); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h index d5bfdcd94c..8633dd0ded 100644 --- a/src/script/printer/base_doc_printer.h +++ b/src/script/printer/base_doc_printer.h @@ -84,22 +84,22 @@ class DocPrinter { virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; /*! - * \brief Virtual method to print a IdDoc + * \brief Virtual method to print an IdDoc */ virtual void PrintTypedDoc(const IdDoc& doc) = 0; /*! - * \brief Virtual method to print a AttrAccessDoc + * \brief Virtual method to print an AttrAccessDoc */ virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0; /*! - * \brief Virtual method to print a IndexDoc + * \brief Virtual method to print an IndexDoc */ virtual void PrintTypedDoc(const IndexDoc& doc) = 0; /*! - * \brief Virtual method to print a OperationDoc + * \brief Virtual method to print an OperationDoc */ virtual void PrintTypedDoc(const OperationDoc& doc) = 0; @@ -133,6 +133,61 @@ class DocPrinter { */ virtual void PrintTypedDoc(const SliceDoc& doc) = 0; + /*! + * \brief Virtual method to print a StmtBlockDoc + */ + virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0; + + /*! + * \brief Virtual method to print an AssignDoc + */ + virtual void PrintTypedDoc(const AssignDoc& doc) = 0; + + /*! + * \brief Virtual method to print an IfDoc + */ + virtual void PrintTypedDoc(const IfDoc& doc) = 0; + + /*! + * \brief Virtual method to print a WhileDoc + */ + virtual void PrintTypedDoc(const WhileDoc& doc) = 0; + + /*! 
+ * \brief Virtual method to print a ForDoc + */ + virtual void PrintTypedDoc(const ForDoc& doc) = 0; + + /*! + * \brief Virtual method to print a ScopeDoc + */ + virtual void PrintTypedDoc(const ScopeDoc& doc) = 0; + + /*! + * \brief Virtual method to print an ExprStmtDoc + */ + virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0; + + /*! + * \brief Virtual method to print an AssertDoc + */ + virtual void PrintTypedDoc(const AssertDoc& doc) = 0; + + /*! + * \brief Virtual method to print a ReturnDoc + */ + virtual void PrintTypedDoc(const ReturnDoc& doc) = 0; + + /*! + * \brief Virtual method to print a FunctionDoc + */ + virtual void PrintTypedDoc(const FunctionDoc& doc) = 0; + + /*! + * \brief Virtual method to print a ClassDoc + */ + virtual void PrintTypedDoc(const ClassDoc& doc) = 0; + /*! * \brief Increase the indent level of any content to be * printed after this call diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 5c7b048f81..f44577ff80 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -16,11 +16,15 @@ * specific language governing permissions and limitations * under the License. 
*/ - #include <tvm/runtime/logging.h> #include <tvm/runtime/registry.h> +#include <tvm/script/printer/doc.h> + +#include <algorithm> +#include <string> #include "../../support/str_escape.h" +#include "../../support/utils.h" #include "./base_doc_printer.h" namespace tvm { @@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter { void PrintTypedDoc(const DictDoc& doc) final; void PrintTypedDoc(const TupleDoc& doc) final; void PrintTypedDoc(const SliceDoc& doc) final; + void PrintTypedDoc(const StmtBlockDoc& doc) final; + void PrintTypedDoc(const AssignDoc& doc) final; + void PrintTypedDoc(const IfDoc& doc) final; + void PrintTypedDoc(const WhileDoc& doc) final; + void PrintTypedDoc(const ForDoc& doc) final; + void PrintTypedDoc(const ExprStmtDoc& doc) final; + void PrintTypedDoc(const AssertDoc& doc) final; + void PrintTypedDoc(const ReturnDoc& doc) final; + void PrintTypedDoc(const ScopeDoc& doc) final; + void PrintTypedDoc(const FunctionDoc& doc) final; + void PrintTypedDoc(const ClassDoc& doc) final; private: + void NewLineWithoutIndent() { output_ << "\n"; } + template <typename DocType> void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) { bool is_first = true; @@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter { PrintDoc(doc); } } + + void PrintIndentedBlock(const Array<StmtDoc>& docs) { + IncreaseIndent(); + for (const StmtDoc& d : docs) { + NewLine(); + PrintDoc(d); + } + if (docs.empty()) { + NewLine(); + output_ << "pass"; + } + DecreaseIndent(); + } + + void PrintDecorators(const Array<ExprDoc>& decorators) { + for (const ExprDoc& decorator : decorators) { + output_ << "@"; + PrintDoc(decorator); + NewLine(); + } + } + + void MaybePrintCommentInline(const StmtDoc& stmt) { + if (stmt->comment.defined()) { + const std::string& comment = stmt->comment.value(); + bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); + CHECK(!has_newline) << "ValueError: the comment string of " << 
stmt->GetTypeKey() + << " cannot have newline."; + output_ << " # " << comment; + } + } + + void MaybePrintCommentWithNewLine(const StmtDoc& stmt) { + if (stmt->comment.defined()) { + std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n'); + for (const std::string& line : comment_lines) { + output_ << "# " << line; + NewLine(); + } + } + } + + void PrintBlockComment(const String& comment) { + IncreaseIndent(); + NewLine() << "\"\"\""; + + std::vector<std::string> comment_lines = support::Split(comment, '\n'); + for (const std::string& line : comment_lines) { + if (line.empty()) { + // No indentation on empty line + output_ << "\n"; + } else { + NewLine() << line; + } + } + + NewLine() << "\"\"\""; + DecreaseIndent(); + } }; void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { @@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { } } +void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) { + for (const StmtDoc& stmt : doc->stmts) { + PrintDoc(stmt); + NewLine(); + } +} + +void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { + if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) { + PrintJoinedDocs(tuple_doc->elements, ", "); + } else { + PrintDoc(doc->lhs); + } + + if (doc->annotation) { + output_ << ": "; + PrintDoc(doc->annotation.value()); + } + if (doc->rhs) { + output_ << " = "; + PrintDoc(doc->rhs.value()); + } + MaybePrintCommentInline(doc); +} + +void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) { + MaybePrintCommentWithNewLine(doc); + output_ << "if "; + PrintDoc(doc->predicate); + output_ << ":"; + + PrintIndentedBlock(doc->then_branch); + + if (!doc->else_branch.empty()) { + NewLine(); + output_ << "else:"; + PrintIndentedBlock(doc->else_branch); + } +} + +void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) { + MaybePrintCommentWithNewLine(doc); + output_ << "while "; + PrintDoc(doc->predicate); + output_ << ":"; + + PrintIndentedBlock(doc->body); +} + +void 
PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) { + MaybePrintCommentWithNewLine(doc); + output_ << "for "; + PrintDoc(doc->lhs); + output_ << " in "; + PrintDoc(doc->rhs); + output_ << ":"; + + PrintIndentedBlock(doc->body); +} + +void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) { + MaybePrintCommentWithNewLine(doc); + output_ << "with "; + PrintDoc(doc->rhs); + if (doc->lhs != nullptr) { + output_ << " as "; + PrintDoc(doc->lhs.value()); + } + output_ << ":"; + + PrintIndentedBlock(doc->body); +} + +void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) { + PrintDoc(doc->expr); + MaybePrintCommentInline(doc); +} + +void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) { + output_ << "assert "; + PrintDoc(doc->test); + if (doc->msg.defined()) { + output_ << ", "; + PrintDoc(doc->msg.value()); + } + MaybePrintCommentInline(doc); +} + +void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) { + output_ << "return "; + PrintDoc(doc->value); + MaybePrintCommentInline(doc); +} + +void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { + for (const AssignDoc& arg_doc : doc->args) { + ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them."; + } + + PrintDecorators(doc->decorators); + + output_ << "def "; + PrintDoc(doc->name); + + output_ << "("; + PrintJoinedDocs(doc->args, ", "); + output_ << ")"; + + output_ << " -> "; + PrintDoc(doc->return_type); + + output_ << ":"; + + if (doc->comment.defined()) { + PrintBlockComment(doc->comment.value()); + } + PrintIndentedBlock(doc->body); + NewLineWithoutIndent(); +} + +void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { + PrintDecorators(doc->decorators); + + output_ << "class "; + PrintDoc(doc->name); + output_ << ":"; + + if (doc->comment.defined()) { + PrintBlockComment(doc->comment.value()); + } + PrintIndentedBlock(doc->body); + NewLineWithoutIndent(); +} + String DocToPythonScript(Doc doc, int indent_spaces) { PythonDocPrinter 
printer(indent_spaces); printer.Append(doc); diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index b65eaa6b98..523f62d8b5 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -15,18 +15,30 @@ # specific language governing permissions and limitations # under the License. import pytest +import itertools from tvm.script.printer.doc import ( + AssertDoc, + AssignDoc, CallDoc, + ClassDoc, DictDoc, + ExprStmtDoc, + ForDoc, + FunctionDoc, IdDoc, + IfDoc, LambdaDoc, ListDoc, LiteralDoc, OperationDoc, OperationKind, + ReturnDoc, + ScopeDoc, SliceDoc, + StmtBlockDoc, TupleDoc, + WhileDoc, ) from tvm.script.printer.doc_printer import to_python_script @@ -36,10 +48,19 @@ def format_script(s: str) -> str: Remove leading and trailing blank lines, and make the minimum idention 0 """ s = s.strip("\n") + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + if not non_empty_lines: + # no actual content + return "\n" + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] spaces_to_remove = min(line_indents) - return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n" + + cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + if not cleaned_lines.endswith("\n"): + cleaned_lines += "\n" + return cleaned_lines @pytest.mark.parametrize( @@ -59,6 +80,7 @@ def format_script(s: str) -> str: (LiteralDoc(3.25), "3.25"), (LiteralDoc(-0.5), "-0.5"), ], + ids=itertools.count(), ) def test_print_literal_doc(doc, expected): assert to_python_script(doc) == format_script(expected) @@ -73,6 +95,7 @@ def test_print_literal_doc(doc, expected): "test_case", "test123", ], + ids=itertools.count(), ) def test_print_id_doc(name): doc = IdDoc(name) @@ -87,6 +110,7 @@ def test_print_id_doc(name): "Attr", 
"attr_1", ], + ids=itertools.count(), ) def test_print_attr_doc(attr): doc = IdDoc("x").attr(attr) @@ -125,6 +149,7 @@ def test_print_attr_doc(attr): "[x, y, z]", ), ], + ids=itertools.count(), ) def test_print_index_doc(indices, expected): doc = IdDoc("x")[indices] @@ -271,6 +296,7 @@ def test_operation_doc_test_exhaustive(): "(x, y, key0=u, key1=v)", ), ], + ids=itertools.count(), ) def test_print_call_doc(args, kwargs, expected): doc = CallDoc(IdDoc("f"), *args, **kwargs) @@ -297,6 +323,7 @@ def test_print_call_doc(args, kwargs, expected): "lambda x, y, z: 0", ), ], + ids=itertools.count(), ) def test_print_lambda_doc(args, expected): doc = LambdaDoc(args, body=LiteralDoc(0)) @@ -323,6 +350,7 @@ def test_print_lambda_doc(args, expected): "[x, y, z]", ), ], + ids=itertools.count(), ) def test_print_list_doc(elements, expected): doc = ListDoc(elements) @@ -349,6 +377,7 @@ def test_print_list_doc(elements, expected): "(x, y, z)", ), ], + ids=itertools.count(), ) def test_print_tuple_doc(elements, expected): doc = TupleDoc(elements) @@ -379,6 +408,7 @@ def test_print_tuple_doc(elements, expected): '{"key_x": x, "key_y": y, "key_z": z}', ), ], + ids=itertools.count(), ) def test_print_dict_doc(content, expected): doc = DictDoc(content) @@ -421,7 +451,590 @@ def test_print_dict_doc(content, expected): "1:2:3", ), ], + ids=itertools.count(), ) def test_print_slice_doc(slice_doc, expected): doc = IdDoc("x")[slice_doc] assert to_python_script(doc) == format_script(f"x[{expected}]") + + +@pytest.mark.parametrize( + "stmts, expected", + [ + ( + [], + "", + ), + ( + [ExprStmtDoc(IdDoc("x"))], + "x", + ), + ( + [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))], + """ + x + y + """, + ), + ], + ids=itertools.count(), +) +def test_print_stmt_block_doc(stmts, expected): + doc = StmtBlockDoc(stmts) + assert to_python_script(doc).strip() == format_script(expected).strip() + + +@pytest.mark.parametrize( + "doc, expected", + [ + ( + AssignDoc(IdDoc("x"), IdDoc("y"), None), + "x = 
y", + ), + ( + AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")), + "x: int = y", + ), + ( + AssignDoc(IdDoc("x"), None, IdDoc("int")), + "x: int", + ), + ( + AssignDoc(TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("z"), None), + "x, y = z", + ), + ( + AssignDoc(TupleDoc([IdDoc("x"), TupleDoc([IdDoc("y"), IdDoc("z")])]), IdDoc("z"), None), + "x, (y, z) = z", + ), + ], + ids=itertools.count(), +) +def test_print_assign_doc(doc, expected): + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "then_branch, else_branch, expected", + [ + ( + [ExprStmtDoc(IdDoc("x"))], + [], + """ + if pred: + x + """, + ), + ( + [], + [ExprStmtDoc(IdDoc("y"))], + """ + if pred: + pass + else: + y + """, + ), + ( + [ExprStmtDoc(IdDoc("x"))], + [ExprStmtDoc(IdDoc("y"))], + """ + if pred: + x + else: + y + """, + ), + ], + ids=itertools.count(), +) +def test_print_if_doc(then_branch, else_branch, expected): + doc = IfDoc(IdDoc("pred"), then_branch, else_branch) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "body, expected", + [ + ( + [ExprStmtDoc(IdDoc("x"))], + """ + while pred: + x + """, + ), + ( + [], + """ + while pred: + pass + """, + ), + ], + ids=itertools.count(), +) +def test_print_while_doc(body, expected): + doc = WhileDoc(IdDoc("pred"), body) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "body, expected", + [ + ( + [ExprStmtDoc(IdDoc("x"))], + """ + for x in y: + x + """, + ), + ( + [], + """ + for x in y: + pass + """, + ), + ], + ids=itertools.count(), +) +def test_print_for_doc(body, expected): + doc = ForDoc(IdDoc("x"), IdDoc("y"), body) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "lhs, body, expected", + [ + ( + IdDoc("c"), + [ExprStmtDoc(IdDoc("x"))], + """ + with context() as c: + x + """, + ), + ( + IdDoc("c"), + [], + """ + with context() as c: + pass + """, + ), + ( + None, + [], + """ + with 
context(): + pass + """, + ), + ( + None, + [ExprStmtDoc(IdDoc("x"))], + """ + with context(): + x + """, + ), + ], + ids=itertools.count(), +) +def test_print_scope_doc(lhs, body, expected): + doc = ScopeDoc(lhs, CallDoc(IdDoc("context")), body) + assert to_python_script(doc) == format_script(expected) + + +def test_print_expr_stmt_doc(): + doc = ExprStmtDoc(CallDoc(IdDoc("f"), IdDoc("x"))) + assert to_python_script(doc) == format_script("f(x)") + + +@pytest.mark.parametrize( + "msg, expected", + [ + ( + None, + """ + assert True + """, + ), + ( + LiteralDoc("test message"), + """ + assert True, "test message" + """, + ), + ], + ids=itertools.count(), +) +def test_print_assert_doc(msg, expected): + test = LiteralDoc(True) + + doc = AssertDoc(test, msg) + + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "value, expected", + [ + ( + LiteralDoc(None), + """ + return None + """, + ), + ( + IdDoc("x"), + """ + return x + """, + ), + ], + ids=itertools.count(), +) +def test_print_return_doc(value, expected): + doc = ReturnDoc(value) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "args, decorators, body, expected", + [ + ( + [], + [], + [], + """ + def func() -> None: + pass + """, + ), + ( + [AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))], + [], + [], + """ + def func(x: int) -> None: + pass + """, + ), + ( + [AssignDoc(IdDoc("x"), rhs=LiteralDoc(1), annotation=IdDoc("int"))], + [], + [], + """ + def func(x: int = 1) -> None: + pass + """, + ), + ( + [], + [IdDoc("wrap")], + [], + """ + @wrap + def func() -> None: + pass + """, + ), + ( + [], + [IdDoc("wrap_outter"), IdDoc("wrap_inner")], + [], + """ + @wrap_outter + @wrap_inner + def func() -> None: + pass + """, + ), + ( + [ + AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")), + AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")), + ], + [IdDoc("wrap")], + [], + """ + @wrap + def func(x: int, y: int = 
1) -> None: + pass + """, + ), + ( + [ + AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")), + AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")), + ], + [IdDoc("wrap")], + [ + AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, [IdDoc("x"), LiteralDoc(1)])), + AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, [IdDoc("y"), LiteralDoc(1)])), + ], + """ + @wrap + def func(x: int, y: int = 1) -> None: + y = x + 1 + y = y - 1 + """, + ), + ], + ids=itertools.count(), +) +def test_print_function_doc(args, decorators, body, expected): + doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body) + assert to_python_script(doc) == format_script(expected) # test + + +def get_func_doc_for_class(name): + args = [ + AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")), + AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")), + ] + body = [ + AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, [IdDoc("x"), LiteralDoc(1)])), + AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, [IdDoc("y"), LiteralDoc(1)])), + ] + return FunctionDoc( + name=IdDoc(name), + args=args, + decorators=[IdDoc("wrap")], + return_type=LiteralDoc(None), + body=body, + ) + + +@pytest.mark.parametrize( + "decorators, body, expected", + [ + ( + [], + [], + """ + class TestClass: + pass + """, + ), + ( + [IdDoc("wrap")], + [], + """ + @wrap + class TestClass: + pass + """, + ), + ( + [IdDoc("wrap_outter"), IdDoc("wrap_inner")], + [], + """ + @wrap_outter + @wrap_inner + class TestClass: + pass + """, + ), + ( + [IdDoc("wrap")], + [get_func_doc_for_class("f1")], + """ + @wrap + class TestClass: + @wrap + def f1(x: int, y: int = 1) -> None: + y = x + 1 + y = y - 1 + + """, + ), + ( + [IdDoc("wrap")], + [get_func_doc_for_class("f1"), get_func_doc_for_class("f2")], + """ + @wrap + class TestClass: + @wrap + def f1(x: int, y: int = 1) -> None: + y = x + 1 + y = y - 1 + + @wrap + def f2(x: int, y: int = 1) -> None: + y = x + 1 + y = y - 1 + + """, + ), + 
], + ids=itertools.count(), +) +def test_print_class_doc(decorators, body, expected): + doc = ClassDoc(IdDoc("TestClass"), decorators, body) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "doc, comment, expected", + [ + ( + AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")), + "comment", + """ + x: int = y # comment + """, + ), + ( + IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], [ExprStmtDoc(IdDoc("z"))]), + "comment", + """ + # comment + if x: + y + else: + z + """, + ), + ( + IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], [ExprStmtDoc(IdDoc("z"))]), + "comment line 1\ncomment line 2", + """ + # comment line 1 + # comment line 2 + if x: + y + else: + z + """, + ), + ( + WhileDoc( + LiteralDoc(True), + [ + AssignDoc(IdDoc("x"), IdDoc("y")), + ], + ), + "comment", + """ + # comment + while True: + x = y + """, + ), + ( + ForDoc(IdDoc("x"), IdDoc("y"), []), + "comment", + """ + # comment + for x in y: + pass + """, + ), + ( + ScopeDoc(IdDoc("x"), IdDoc("y"), []), + "comment", + """ + # comment + with y as x: + pass + """, + ), + ( + ExprStmtDoc(IdDoc("x")), + "comment", + """ + x # comment + """, + ), + ( + AssertDoc(LiteralDoc(True)), + "comment", + """ + assert True # comment + """, + ), + ( + ReturnDoc(LiteralDoc(1)), + "comment", + """ + return 1 # comment + """, + ), + ( + get_func_doc_for_class("f"), + "comment", + ''' + @wrap + def f(x: int, y: int = 1) -> None: + """ + comment + """ + y = x + 1 + y = y - 1 + ''', + ), + ( + get_func_doc_for_class("f"), + "comment line 1\n\ncomment line 3", + ''' + @wrap + def f(x: int, y: int = 1) -> None: + """ + comment line 1 + + comment line 3 + """ + y = x + 1 + y = y - 1 + ''', + ), + ( + ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]), + "comment", + ''' + @wrap + class TestClass: + """ + comment + """ + pass + ''', + ), + ( + ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]), + "comment line 1\n\ncomment line 3", + ''' + @wrap + class TestClass: + 
""" + comment line 1 + + comment line 3 + """ + pass + ''', + ), + ], + ids=itertools.count(), +) +def test_print_doc_comment(doc, comment, expected): + doc.comment = comment + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "doc", + [ + AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")), + ExprStmtDoc(IdDoc("x")), + AssertDoc(IdDoc("x")), + ReturnDoc(IdDoc("x")), + ], +) +def test_print_invalid_multiline_doc_comment(doc): + doc.comment = "1\n2" + with pytest.raises(ValueError) as e: + to_python_script(doc) + assert "cannot have newline" in str(e.value)