This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3effa45b1f [Unity][Debugging] AST printer (#14152)
3effa45b1f is described below

commit 3effa45b1fe7f68d2a57db0743c5051b850ecdc2
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 28 23:24:02 2023 -0500

    [Unity][Debugging] AST printer (#14152)
    
    This PR transfers the AST printer over from tlc-pack/relax. The AST printer is a debugging tool that prints a Relax AST in a precise, human-readable format, which can be helpful for debugging the parser or various passes.
    
    Co-authored-by: Yuchen Jin <[email protected]>
    Co-authored-by: Lesheng Jin <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Tianqi Chen <[email protected]>
---
 python/tvm/relax/testing/__init__.py    |   1 +
 python/tvm/relax/testing/ast_printer.py | 372 +++++++++++++++++++
 tests/python/relax/test_ast_printer.py  | 636 ++++++++++++++++++++++++++++++++
 3 files changed, 1009 insertions(+)

diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py
index 7344798f70..a6e3a94251 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -19,3 +19,4 @@
 
 from .nn import *
 from .relay_translator import *
+from .ast_printer import dump_ast
diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py
new file mode 100644
index 0000000000..6727b24292
--- /dev/null
+++ b/python/tvm/relax/testing/ast_printer.py
@@ -0,0 +1,372 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, abstract-method, arguments-differ
+"""
+Utility script for printing Relax modules as AST diagrams,
+only intended to show how the AST is put together.
+It is not a pretty-printer and, in fact, is more of an ugly-printer,
+but it can be useful for tutorials and debugging.
+"""
+from typing import Iterable
+import tvm
+from tvm import relax
+from tvm.ir.expr import PrimExpr
+from tvm.relax import ExprFunctor
+
+
+def wrap_quotes(text: str) -> str:
+    """
+    Return ``text`` enclosed in double quotes.
+    Quotes already inside ``text`` are left untouched (no escaping is done).
+    """
+    quote = '"'
+    return f"{quote}{text}{quote}"
+
+
+class ASTPrinter(ExprFunctor):
+    """
+    Class for recursing down ASTs and printing them in a very simple format,
+    mainly for instructive purposes and, perhaps, debugging.
+
+    Parameters
+    ----------
+    indent_str : str
+        String prepended to every line of a nested member, once per nesting level.
+    include_struct_info_annotations : bool
+        Whether expressions print their ``struct_info`` field (when it is set).
+    include_type_annotations : bool
+        Whether expressions print their ``checked_type_`` field (when it is set).
+    include_call_attrs : bool
+        Whether ``Call`` nodes print their ``attrs`` field.
+    """
+
+    def __init__(
+        self,
+        indent_str="    ",
+        include_struct_info_annotations=True,
+        include_type_annotations=False,
+        include_call_attrs=True,
+    ):
+        self.indent_str = indent_str
+        self.include_type_annotations = include_type_annotations
+        self.include_struct_info_annotations = include_struct_info_annotations
+        self.include_call_attrs = include_call_attrs
+
+    def visit_expr(self, expr: relax.Expr) -> str:
+        # extend so we also dispatch to bindings and binding blocks,
+        # a little silly but IRFunctor hasn't been ported to Python.
+        # DataflowBlock is checked first so it is not handled as a plain BindingBlock.
+        if isinstance(expr, relax.DataflowBlock):
+            return self.visit_dataflow_block_(expr)
+        if isinstance(expr, relax.BindingBlock):
+            return self.visit_binding_block_(expr)
+        if isinstance(expr, relax.Binding):
+            return self.visit_binding_(expr)
+        return super().visit_expr(expr)
+
+    def indent(self, text: str) -> str:
+        """
+        Indent all lines of the input.
+        """
+        if text == "":
+            return ""
+        lines = text.split("\n")
+        return self.indent_str + f"\n{self.indent_str}".join(lines)
+
+    def build_ast_node(self, nodename: str, force_newline=False, **kwargs: str) -> str:
+        """
+        Returns 'nodename(..., fields[i][0]=fields[i][1], ...)'
+        with appropriate indentation
+        """
+        return self.build_list(
+            map(lambda field: f"{field[0]}={field[1]}", kwargs.items()),
+            open_tok=f"{nodename}(",
+            close_tok=")",
+            force_newline=force_newline,
+        )
+
+    def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwargs: str):
+        """
+        Renders a Relax expression as a string using `build_ast_node`.
+        Handles whether to include the checked_type_ and struct_info fields.
+        """
+        fields = kwargs.copy()
+        # Compare against None explicitly: an annotation that is set must be printed
+        # even if the annotation object itself happens to be falsy.
+        if node.struct_info_ is not None and self.include_struct_info_annotations:
+            fields["struct_info"] = self.visit_struct_info_(node.struct_info)
+        if node._checked_type_ is not None and self.include_type_annotations:
+            fields["checked_type_"] = self.visit_type_(node.checked_type)
+        return self.build_ast_node(nodename, force_newline=force_newline, **fields)
+
+    def build_list(
+        self, members: Iterable[str], open_tok="[", close_tok="]", force_newline=False
+    ) -> str:
+        """
+        Builds a list of the members given, appropriately indented,
+        with each field on a line.
+        (special case: if there is only one field, then we do not put it on a new line
+        unless that field contains a newline or `force_newline` is set to true).
+        `open_tok` and `close_tok` are used to open and close the list, respectively.
+        """
+        mem_list = list(members)
+        if not mem_list:
+            return f"{open_tok}{close_tok}"
+        if len(mem_list) == 1 and not force_newline and "\n" not in mem_list[0]:
+            return f"{open_tok}{mem_list[0]}{close_tok}"
+        member_lines = ",\n".join(map(self.indent, mem_list))
+        return f"{open_tok}\n{member_lines}\n{close_tok}"
+
+    def visit_constant_(self, op: relax.Constant) -> str:
+        # simple rule of thumb: keep scalars inline, but anything larger goes on a new one
+        force_newline = len(op.data.shape) > 0
+        return self.build_expr(op, "Constant", force_newline=force_newline, data=str(op.data))
+
+    def visit_tuple_(self, op: relax.Tuple) -> str:
+        return self.build_expr(op, "Tuple", fields=self.build_list(map(self.visit_expr, op.fields)))
+
+    def visit_dataflow_var_(self, op: relax.DataflowVar) -> str:
+        return self.build_expr(op, "DataflowVar", name_hint=wrap_quotes(op.name_hint))
+
+    def visit_var_(self, op: relax.Var) -> str:
+        return self.build_expr(op, "Var", name_hint=wrap_quotes(op.name_hint))
+
+    def visit_shape_expr_(self, op: relax.ShapeExpr) -> str:
+        return self.build_expr(
+            op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values))
+        )
+
+    def visit_extern_func_(self, op: relax.ExternFunc) -> str:
+        # ExternFunc is printed without checked_type_ or struct_info fields,
+        # so build_expr is not used here
+        return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol))
+
+    def visit_global_var_(self, op: relax.GlobalVar) -> str:
+        return self.build_expr(op, "GlobalVar", name_hint=wrap_quotes(op.name_hint))
+
+    def visit_function_(self, op: relax.Function) -> str:
+        fields = {
+            "params": self.build_list(map(self.visit_expr, op.params)),
+            "body": self.visit_expr(op.body),
+            "ret_struct_info": self.visit_struct_info_(op.ret_struct_info),
+        }
+        if op.attrs:
+            fields["attrs"] = self.build_list(
+                map(
+                    lambda kv: f"{wrap_quotes(str(kv[0]))}: {wrap_quotes(str(kv[1]))}",
+                    op.attrs.items(),
+                ),
+                open_tok="{",
+                close_tok="}",
+            )
+        return self.build_expr(op, "Function", **fields)
+
+    def visit_call_(self, op: relax.Call) -> str:
+        fields = {
+            "op": self.visit_expr(op.op),
+            "args": self.build_list(map(self.visit_expr, op.args)),
+        }
+        if op.sinfo_args:
+            fields["sinfo_args"] = self.build_list(map(self.visit_struct_info_, op.sinfo_args))
+        if op.attrs and self.include_call_attrs:
+
+            def display_attrs(attr_key):
+                attr_val = op.attrs[attr_key]
+                # attrs can be strings but also other types;
+                # we want to wrap strings in quotes
+                # (__repr__ would work but it uses single quotes)
+                attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val)
+                return f"{wrap_quotes(attr_key)}: {attr_str}"
+
+            fields["attrs"] = self.build_list(
+                map(display_attrs, op.attrs.keys()),
+                open_tok="{",
+                close_tok="}",
+            )
+        return self.build_expr(op, "Call", **fields)
+
+    def visit_seq_expr_(self, op: relax.SeqExpr) -> str:
+        # Route each block through visit_expr so that DataflowBlocks are printed as
+        # "DataflowBlock" rather than every block being labelled "BindingBlock".
+        return self.build_expr(
+            op,
+            "SeqExpr",
+            blocks=self.build_list(map(self.visit_expr, op.blocks)),
+            body=self.visit_expr(op.body),
+        )
+
+    def visit_if_(self, op: relax.If) -> str:
+        return self.build_expr(
+            op,
+            "If",
+            cond=self.visit_expr(op.cond),
+            true_branch=self.visit_expr(op.true_branch),
+            false_branch=self.visit_expr(op.false_branch),
+        )
+
+    def visit_prim_value_(self, op: relax.PrimValue) -> str:
+        return self.build_expr(op, "PrimValue", value=self.visit_prim_expr_(op.value))
+
+    def visit_string_imm_(self, op: relax.StringImm) -> str:
+        return self.build_expr(op, "StringImm", value=wrap_quotes(op.value))
+
+    def visit_data_type_imm_(self, op: relax.DataTypeImm) -> str:
+        return self.build_expr(op, "DataTypeImm", value=op.value)
+
+    def visit_op_(self, op: tvm.ir.Op) -> str:
+        # TODO: List other attributes?
+        # op is not actually a Relax expr and does not have checked_type_
+        # or struct_info fields, so we don't use build_expr here
+        return self.build_ast_node("Op", name=wrap_quotes(op.name))
+
+    def visit_prim_expr_(self, prim_expr: PrimExpr) -> str:
+        # TODO: We may want to print PrimExpr ASTs, but this is a simplification for now
+        return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`")
+
+    def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str:
+        return self.build_expr(
+            op,
+            "TupleGetItem",
+            tuple_value=self.visit_expr(op.tuple_value),
+            index=str(op.index),
+        )
+
+    def visit_type_(self, type_node: relax.Type) -> str:
+        """
+        Recurse down types and print their ASTs too
+        """
+        if isinstance(type_node, relax.ShapeType):
+            return self.build_ast_node("ShapeType", ndim=str(type_node.ndim))
+        if isinstance(type_node, relax.ObjectType):
+            return self.build_ast_node("ObjectType")
+        if isinstance(type_node, relax.PackedFuncType):
+            return self.build_ast_node("PackedFuncType")
+        if isinstance(type_node, tvm.ir.PrimType):
+            return self.build_ast_node("PrimType", dtype=type_node.dtype)
+        if isinstance(type_node, relax.DynTensorType):
+            fields = {}
+            if type_node.ndim is not None:
+                fields["ndim"] = str(type_node.ndim)
+            if type_node.dtype != "":
+                fields["dtype"] = type_node.dtype
+            return self.build_ast_node("DynTensorType", **fields)
+        if isinstance(type_node, relax.TupleType):
+            return self.build_ast_node(
+                "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields))
+            )
+        if isinstance(type_node, relax.FuncType):
+            return self.build_ast_node(
+                "FuncType",
+                arg_types=self.build_list(map(self.visit_type_, type_node.arg_types)),
+                ret_type=self.visit_type_(type_node.ret_type),
+                # TODO: skipping type params and type constraints
+            )
+        raise ValueError(f"Invalid Relax Type {type_node} ({type(type_node)})")
+
+    def visit_struct_info_(self, struct_info_node: relax.StructInfo) -> str:
+        """
+        Recurse down struct info and print their ASTs too
+        """
+        if isinstance(struct_info_node, relax.ShapeStructInfo):
+            fields = {}
+            fields["ndim"] = str(struct_info_node.ndim)
+            if struct_info_node.values is not None:
+                fields["values"] = self.build_list(
+                    map(self.visit_prim_expr_, struct_info_node.values)
+                )
+            return self.build_ast_node("ShapeStructInfo", **fields)
+        if isinstance(struct_info_node, relax.ObjectStructInfo):
+            return self.build_ast_node("ObjectStructInfo")
+        if isinstance(struct_info_node, relax.PrimStructInfo):
+            return self.build_ast_node("PrimStructInfo", dtype=struct_info_node.dtype)
+        if isinstance(struct_info_node, relax.TensorStructInfo):
+            fields = {}
+            fields["dtype"] = struct_info_node.dtype
+            # Compare with None instead of relying on truthiness: the shape may be an
+            # expression whose truth value depends on its contents (e.g. a ShapeExpr
+            # with no values, for a rank-0 tensor), and it should still be printed.
+            if struct_info_node.shape is not None:
+                fields["shape"] = self.visit_expr(struct_info_node.shape)
+            else:
+                fields["ndim"] = str(struct_info_node.ndim)
+            return self.build_ast_node("TensorStructInfo", **fields)
+        if isinstance(struct_info_node, relax.TupleStructInfo):
+            return self.build_ast_node(
+                "TupleStructInfo",
+                fields=self.build_list(map(self.visit_struct_info_, struct_info_node.fields)),
+            )
+        if isinstance(struct_info_node, relax.FuncStructInfo):
+            fields = {}
+            if struct_info_node.params is not None:
+                fields["params"] = self.build_list(
+                    map(self.visit_struct_info_, struct_info_node.params)
+                )
+            fields["ret"] = self.visit_struct_info_(struct_info_node.ret)
+            return self.build_ast_node("FuncStructInfo", **fields)
+        raise ValueError(
+            f"Invalid Relax StructInfo {struct_info_node} ({type(struct_info_node)})"
+        )
+
+    def visit_binding_block_(self, block: relax.BindingBlock) -> str:
+        """
+        Recurse down binding blocks
+        """
+        return self.build_ast_node(
+            "BindingBlock",
+            bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True),
+        )
+
+    def visit_dataflow_block_(self, block: relax.DataflowBlock) -> str:
+        """
+        Recurse down a dataflow block
+        """
+        return self.build_ast_node(
+            "DataflowBlock",
+            bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True),
+        )
+
+    def visit_binding_(self, binding: relax.Binding) -> str:
+        """
+        Distinguish between binding types
+        """
+        if isinstance(binding, relax.MatchCast):
+            return self.visit_match_cast_(binding)
+        if isinstance(binding, relax.VarBinding):
+            return self.visit_var_binding_(binding)
+        raise ValueError(f"Invalid binding type in {binding}: {type(binding)}")
+
+    def visit_match_cast_(self, match_cast: relax.MatchCast) -> str:
+        """
+        Handle match_cast bindings (var, matched value, and target struct info)
+        """
+        fields = {
+            "var": self.visit_expr(match_cast.var),
+            "value": self.visit_expr(match_cast.value),
+            "struct_info": self.visit_struct_info_(match_cast.struct_info),
+        }
+        return self.build_ast_node("MatchCast", **fields)
+
+    def visit_var_binding_(self, var_binding: relax.VarBinding) -> str:
+        """
+        Handle ordinary var bindings
+        """
+        return self.build_ast_node(
+            "VarBinding",
+            var=self.visit_expr(var_binding.var),
+            value=self.visit_expr(var_binding.value),
+        )
+
+
+def dump_ast(
+    exp: relax.Expr,
+    indent_str="    ",
+    include_struct_info_annotations=True,
+    include_type_annotations=False,
+    include_call_attrs=True,
+) -> str:
+    """
+    Render ``exp`` (an expression, binding, or binding block) as an AST dump.
+
+    The indentation string is configurable, as is whether struct info,
+    checked types, and call attributes appear in the output.
+    """
+    options = {
+        "indent_str": indent_str,
+        "include_struct_info_annotations": include_struct_info_annotations,
+        "include_type_annotations": include_type_annotations,
+        "include_call_attrs": include_call_attrs,
+    }
+    return ASTPrinter(**options).visit_expr(exp)
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
new file mode 100644
index 0000000000..ba3c930a45
--- /dev/null
+++ b/tests/python/relax/test_ast_printer.py
@@ -0,0 +1,636 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import re
+from functools import partial
+from typing import Dict
+
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax as rx
+from tvm import tir
+from tvm.relax.testing import dump_ast
+from tvm.relax.testing.ast_printer import ASTPrinter
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+# Rebind dump_ast so that, unless a test overrides them, both struct info and
+# checked_type_ annotations are included in every dump.
+dump_ast = partial(
+    dump_ast,
+    include_type_annotations=True,
+    include_struct_info_annotations=True,
+)
+
+
+def strip_whitespace(text: str) -> str:
+    """
+    Return ``text`` with every whitespace character (spaces, tabs, newlines, ...)
+    removed, so comparisons need not care about line breaks or indentation.
+    """
+    return "".join(re.split(r"\s", text))
+
+
+def normalize(func: rx.Function) -> rx.Function:
+    """
+    Re-run ``func`` through a default PyExprMutator so that checked_type_ and
+    struct_info get filled in everywhere, and return the resulting function.
+    """
+
+    # A mutator with no overrides still goes through the BlockBuilder's normalizer,
+    # which (oddly) does not behave the same as the Normalize pass.
+    @rx.expr_functor.mutator
+    class DefaultMutator(rx.PyExprMutator):
+        pass
+
+    module = tvm.IRModule()
+    module["main"] = func
+    normalized = DefaultMutator(module).visit_expr(func)
+    module["main"] = normalized
+    return module["main"]
+
+
+def assert_fields(nodename: str, fields: Dict[str, str], target: str) -> None:
+    """
+    Assert that ``target`` is a dump of a ``nodename`` node and that it contains
+    ``field=value`` for every entry of ``fields``.
+    Whitespace is removed from both sides first, and field order is not checked.
+    """
+    compact = strip_whitespace(target)
+    assert compact.startswith(nodename + "(")
+    for name, expected in fields.items():
+        needle = f"{name}={strip_whitespace(expected)}"
+        assert needle in compact
+
+
+# test cases are mostly adapted from test_expr and only test very basic properties
+
+
+def test_var() -> None:
+    # bare var: there are no annotations to print
+    plain = rx.Var("v0")
+    assert dump_ast(plain) == 'Var(name_hint="v0")'
+
+    # annotated var: annotations only show up when they are enabled
+    annotated = rx.Var("v1", R.Tensor([54, 96], "float32"))
+    bare_text = dump_ast(
+        annotated, include_struct_info_annotations=False, include_type_annotations=False
+    )
+    assert bare_text == 'Var(name_hint="v1")'
+    full_text = dump_ast(annotated)
+    assert full_text != bare_text
+    for fragment in ("PrimExpr", "struct_info", "checked_type_"):
+        assert fragment in full_text
+
+
+def test_dataflow_var() -> None:
+    # bare dataflow var: there are no annotations to print
+    plain = rx.DataflowVar("v0")
+    assert dump_ast(plain) == 'DataflowVar(name_hint="v0")'
+
+    # annotated dataflow var: annotations only show up when they are enabled
+    annotated = rx.DataflowVar("v1", R.Tensor([54, 96], "float16"))
+    bare_text = dump_ast(
+        annotated, include_struct_info_annotations=False, include_type_annotations=False
+    )
+    assert bare_text == 'DataflowVar(name_hint="v1")'
+    full_text = dump_ast(annotated)
+    assert full_text != bare_text
+    for fragment in ("PrimExpr", "struct_info", "checked_type_"):
+        assert fragment in full_text
+
+
+def test_match_cast() -> None:
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+
+    # match_cast([16, 8], [m, n])
+    const_shape = rx.const([16, 8], "int32")
+    shape_var = rx.Var("v0", R.Shape())
+    cast0 = rx.MatchCast(shape_var, const_shape, R.Tensor([m, n], "int32"))
+    text0 = dump_ast(cast0)
+    assert text0.startswith("MatchCast(")
+    for fragment in ("Constant", "PrimExpr(value=`m", "PrimExpr(value=`n", "16", "8"):
+        assert fragment in text0
+    assert text0 != dump_ast(cast0, include_type_annotations=False)
+
+    # var1: Tensor((m, n), "float32") =
+    #   match_cast(var0: R.Tensor("float32"), [m, n])
+    source = rx.Var("value", R.Tensor("float32"))
+    target = rx.Var("v1", R.Tensor([m, n], "float32"))
+    cast1 = rx.MatchCast(target, source, R.Tensor([m, n], "float32"))
+    text1 = dump_ast(cast1)
+    assert text1.startswith("MatchCast(")
+    for fragment in ("PrimExpr(value=`m", "PrimExpr(value=`n"):
+        assert fragment in text1
+    assert text1 != dump_ast(
+        cast1, include_type_annotations=False, include_struct_info_annotations=False
+    )
+
+
+def test_var_binding() -> None:
+    binding = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))
+    # annotations are switched off so only the binding's own structure is printed
+    text = dump_ast(
+        binding, include_type_annotations=False, include_struct_info_annotations=False
+    )
+    assert text.startswith("VarBinding(")
+    assert 'var=Var(name_hint="v0")' in text
+    for fragment in ("value=", "Constant("):
+        assert fragment in text
+
+
+def test_binding_block() -> None:
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    cast = rx.MatchCast(rx.Var("v0"), rx.const([16, 8], "int32"), R.Tensor([m, n], "int32"))
+    binding = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))
+
+    text = dump_ast(rx.BindingBlock([cast, binding]))
+    assert text.startswith("BindingBlock(")
+    for fragment in ("bindings=", "VarBinding(", "MatchCast(", '"v0"'):
+        assert fragment in text
+
+
+def test_dataflow_block() -> None:
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    match = rx.MatchCast(rx.Var("v0"), rx.const([16, 8], "int32"), R.Tensor([m, n], "int32"))
+    assign = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))
+
+    dumped = dump_ast(rx.DataflowBlock([match, assign]))
+    assert dumped.startswith("DataflowBlock(")
+    for piece in ("bindings=", "VarBinding(", "MatchCast(", '"v0"'):
+        assert piece in dumped
+
+
+def test_seq_expr() -> None:
+    foo = rx.Var("foo")
+    seq = rx.SeqExpr([rx.BindingBlock([rx.VarBinding(foo, rx.const(1))])], foo)
+    text = dump_ast(seq)
+    assert text.startswith("SeqExpr(")
+    for fragment in (
+        "blocks=",
+        "BindingBlock(",
+        "VarBinding(",
+        "Constant(",
+        'var=Var(name_hint="foo")',
+    ):
+        assert fragment in text
+    assert "value=Constant(data" in strip_whitespace(text)
+    assert "body=" in text
+
+
+def test_shape_expr() -> None:
+    # Symbolic dims use int64, the dtype ShapeExpr stores its values in (see
+    # test_shape_expr_int_values, where int literals print as T.int64). Only the
+    # printed prefix is checked, as in test_match_cast, so the assertion does not
+    # depend on how the var's dtype is rendered.
+    # NOTE(review): this test used int32 vars and was never collected, because a
+    # second function with the same name shadowed it; confirm it passes now.
+    m = tir.Var("m", dtype="int64")
+    n = tir.Var("n", dtype="int64")
+    s = rx.ShapeExpr([m, n])
+    s_str = dump_ast(s)
+    assert s_str.startswith("ShapeExpr(")
+    assert "values=" in s_str
+    assert "PrimExpr(value=`m" in s_str
+    assert "PrimExpr(value=`n" in s_str
+
+
+def test_func():
+    x = rx.Var("foo", R.Tensor("float32", ndim=2))
+    bindings = [rx.VarBinding(x, rx.const(1))]
+    blocks = [rx.BindingBlock(bindings)]
+    seqe = rx.SeqExpr(blocks, x)
+    func = rx.Function([x], seqe, R.Tensor("float32"))
+    func = func.with_attr("global_symbol", "func")
+
+    func_str = dump_ast(func)
+    assert func_str.startswith("Function(")
+    assert "params=" in func_str
+    assert "body=" in func_str
+    assert "ret_struct_info=" in func_str
+    assert "attrs=" in func_str
+    assert '"global_symbol": "func"' in func_str
+    assert "SeqExpr(" in func_str
+    assert "blocks=" in func_str
+    assert "VarBinding(" in func_str
+    assert func_str != dump_ast(func, include_type_annotations=False)
+
+
+def test_shape_of():
+    # unknown shape: shape_of stays a call to the relax.shape_of op
+    v0 = rx.Var("v0", R.Tensor(ndim=2))
+    s0 = rx.get_shape_of(v0)
+    s0_str = dump_ast(s0)
+    assert s0_str.startswith("Call(")
+    assert 'op=Op(name="relax.shape_of")' in s0_str
+    assert "args=" in s0_str
+    assert 'name_hint="v0"' in s0_str
+
+    # known shape: get_shape_of is expected to produce a ShapeExpr directly
+    v1 = rx.Var("v1", R.Tensor([96, 54]))
+    s1 = rx.get_shape_of(v1)
+    s1_str = dump_ast(s1)
+    assert s1_str.startswith("ShapeExpr("), s1_str
+    assert "values=" in s1_str
+    assert "PrimExpr(value=`T.int64(96)`)" in s1_str
+    assert "PrimExpr(value=`T.int64(54)`)" in s1_str
+
+
+def test_shape_expr_int_values():
+    # Formerly a second `test_shape_expr`, which shadowed the symbolic-dim test above.
+    shape_expr = rx.ShapeExpr([10, 20])
+    shape_expr_str = dump_ast(shape_expr)
+    assert shape_expr_str.startswith("ShapeExpr(")
+    assert "values" in shape_expr_str
+    assert "PrimExpr(value=`T.int64(10)`)" in shape_expr_str
+    assert "PrimExpr(value=`T.int64(20)`)" in shape_expr_str
+
+
+def test_types():
+    printer = ASTPrinter()
+
+    def render(type_node):
+        return strip_whitespace(printer.visit_type_(type_node))
+
+    assert render(rx.ShapeType()) == "ShapeType(ndim=-1)"
+    assert render(rx.ShapeType(ndim=1)) == "ShapeType(ndim=1)"
+    obj_ty = rx.ObjectType()
+    assert render(obj_ty) == "ObjectType()"
+    assert render(rx.PackedFuncType()) == "PackedFuncType()"
+    dyn_ty = rx.DynTensorType(ndim=2, dtype="int32")
+    assert render(dyn_ty) == "DynTensorType(ndim=2,dtype=int32)"
+    unit_ty = rx.TupleType([])
+    assert render(unit_ty) == "TupleType(fields=[])"
+    assert_fields(
+        "TupleType",
+        {"fields": "[ShapeType(ndim=-1),ObjectType()]"},
+        render(rx.TupleType([rx.ShapeType(), obj_ty])),
+    )
+
+    # assert_fields strips whitespace itself, so the raw rendering is passed here
+    assert_fields(
+        "FuncType",
+        {"arg_types": "[DynTensorType(ndim=2, dtype=int32)]", "ret_type": "TupleType(fields=[])"},
+        printer.visit_type_(rx.FuncType([dyn_ty], unit_ty)),
+    )
+
+
+def test_struct_info():
+    printer = ASTPrinter(include_type_annotations=True)
+    show = printer.visit_struct_info_
+
+    assert show(rx.ObjectStructInfo()) == "ObjectStructInfo()"
+    assert show(rx.PrimStructInfo("int32")) == "PrimStructInfo(dtype=int32)"
+
+    # shape whose rank is unknown
+    unknown_shape = rx.ShapeStructInfo()
+    assert show(unknown_shape) == "ShapeStructInfo(ndim=-1)"
+
+    # shape with known dimensions
+    known_shape = rx.ShapeStructInfo([tir.IntImm("int64", 1), tir.IntImm("int64", 2)])
+    expected_known = """
+        ShapeStructInfo(
+            ndim=2,
+            values=[
+                PrimExpr(value=`T.int64(1)`),
+                PrimExpr(value=`T.int64(2)`)
+            ]
+        )
+        """
+    assert strip_whitespace(show(known_shape)) == strip_whitespace(expected_known)
+
+    # tensor struct info with every field left at its default
+    assert (
+        strip_whitespace(show(rx.TensorStructInfo()))
+        == "TensorStructInfo(dtype=float32,ndim=-1)"
+    )
+
+    # a Var (rather than a ShapeExpr) used as the tensor's shape
+    shape_var = rx.Var("x", struct_info=rx.ShapeStructInfo(values=[]))
+    var_shaped = rx.TensorStructInfo(shape=shape_var, dtype="int32")
+    expected_var_shaped = """
+        TensorStructInfo(
+            dtype=int32,
+            shape=Var(
+                name_hint="x",
+                struct_info=ShapeStructInfo(ndim=0, values=[]),
+                checked_type_=ShapeType(ndim=0)
+            )
+        )
+        """
+    assert strip_whitespace(show(var_shaped)) == strip_whitespace(expected_var_shaped)
+
+    assert show(rx.TupleStructInfo([])) == "TupleStructInfo(fields=[])"
+
+    expected_tuple = """
+        TupleStructInfo(fields=[
+            ShapeStructInfo(ndim=-1)
+        ])
+        """
+    tuple_text = show(rx.TupleStructInfo([unknown_shape]))
+    assert strip_whitespace(tuple_text) == strip_whitespace(expected_tuple)
+
+    assert (
+        strip_whitespace(show(rx.FuncStructInfo([], rx.ObjectStructInfo())))
+        == "FuncStructInfo(params=[],ret=ObjectStructInfo())"
+    )
+
+
+def test_call_packed():
+    """
+    Dump a TVMScript function containing operator calls and an R.call_packed call.
+    Check that the dumps of individual calls appear in the whole-function dump
+    (whitespace stripped) and carry the expected op/args/sinfo_args/attrs fields.
+    """
+    # test case from test_parser
+    # NOTE: the body of `f` is parsed by TVMScript, so its source (e.g. the names
+    # x and y) determines the printed output checked below; keep it unchanged.
+    @R.function
+    def f(
+        x: R.Tensor((32, "m"), "float32"),
+        y: R.Tensor(("m",), "float32"),
+        r: R.Tensor(dtype="int64"),
+    ) -> R.Object:
+        m = T.var("int64")
+        z: R.Tensor((32, m), "float32") = R.multiply(x, y)
+        w: R.Tensor = R.multiply(z, z)
+        q: R.Tensor(ndim=2) = R.add(w, w)
+        t = R.add(w, z)
+        sh: R.Shape = R.shape_of(t)
+        o: R.Object = R.call_packed(
+            "contrib.tensor_array_stack", x, y, sinfo_args=R.Object(), test_attr=True
+        )
+        return o
+
+    # checking that the call_packed call is turned into a call to an extern func
+    f_str = strip_whitespace(
+        dump_ast(
+            f,
+            include_type_annotations=False,
+            include_struct_info_annotations=False,
+            include_call_attrs=True,
+        )
+    )
+
+    # the function has an annotated return type
+    assert "ret_struct_info=ObjectStructInfo()" in f_str
+
+    assert isinstance(f.body, rx.SeqExpr)
+    # the last binding of the first block is `o = R.call_packed(...)`
+    extern_call = f.body.blocks[0].bindings[-1].value
+    extern_call_text = dump_ast(
+        extern_call,
+        include_type_annotations=False,
+        include_struct_info_annotations=False,
+        include_call_attrs=True,
+    )
+    assert strip_whitespace(extern_call_text) in f_str
+    # test_attr=True is expected to be printed as 1
+    assert_fields(
+        "Call",
+        {
+            "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")',
+            "args": '[Var(name_hint="x"), Var(name_hint="y")]',
+            "sinfo_args": "[ObjectStructInfo()]",
+            "attrs": '{"test_attr": 1}',
+        },
+        extern_call_text,
+    )
+
+    # check that the op call is there too
+    # (the first binding of the first block is `z = R.multiply(x, y)`)
+    op_call = f.body.blocks[0].bindings[0].value
+    op_call_text = dump_ast(
+        op_call,
+        include_type_annotations=False,
+        include_struct_info_annotations=False,
+        include_call_attrs=True,
+    )
+    assert strip_whitespace(op_call_text) in f_str
+    assert_fields(
+        "Call",
+        {
+            "op": 'Op(name="relax.multiply")',
+            "args": '[Var(name_hint="x"), Var(name_hint="y")]',
+        },
+        op_call_text,
+    )
+
+    # TODO: add testcase for op attrs
+
+
+def test_call_tir():
+    # also from test_parser
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), "float32")):
+        m, n = T.var("int64"), T.var("int64")
+        gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
+        return gv0
+
+    foo_str = strip_whitespace(
+        dump_ast(
+            foo,
+            include_type_annotations=False,
+            include_struct_info_annotations=False,
+            include_call_attrs=False,
+        )
+    )
+    assert foo_str.startswith('Function(params=[Var(name_hint="x")]')
+
+    # call_tir is an op in Relax and it takes an extern func as an argument
+    assert isinstance(foo.body, rx.SeqExpr)
+    tir_call = foo.body.blocks[0].bindings[0].value
+    tir_call_text = dump_ast(
+        tir_call,
+        include_type_annotations=False,
+        include_struct_info_annotations=False,
+        include_call_attrs=False,
+    )
+    assert_fields(
+        "Call",
+        {
+            "op": 'Op(name="relax.call_tir")',
+            "args": """[
+                ExternFunc(global_symbol="test.op.identity"),
+                Tuple(fields=[Var(name_hint="x")])
+            ]""",
+            "sinfo_args": """[
+                TensorStructInfo(
+                    dtype=float32,
+                    shape=ShapeExpr(
+                        values=[
+                            PrimExpr(value=`m`),
+                            PrimExpr(value=`n`)
+                        ]
+                    )
+                )
+            ]""",
+        },
+        tir_call_text,
+    )
+    assert strip_whitespace(tir_call_text) in foo_str
+
+
+def test_operators():
+    @R.function
+    def foo(x: R.Tensor):
+        return R.unique(x, sorted=True, axis=-1)
+
+    foo_str = strip_whitespace(
+        dump_ast(
+            foo,
+            include_type_annotations=False,
+            include_struct_info_annotations=False,
+        )
+    )
+    assert 'Op(name="relax.unique")' in foo_str
+    # the sorted argument is true, so it will be a PrimValue of 1
+    assert "PrimExpr(value=`T.int64(1)`)" in foo_str
+    # axis is -1
+    assert "PrimExpr(value=`T.int64(-1)`)" in foo_str
+
+    @R.function
+    def bar(x: R.Tensor):
+        return R.print(x, format="{}")
+
+    bar_str = strip_whitespace(
+        dump_ast(
+            bar,
+            include_type_annotations=False,
+            include_struct_info_annotations=False,
+        )
+    )
+    # the format string is a StringImm argument
+    assert 'StringImm(value="{}")' in bar_str
+
+
+def test_print_struct_info_annotation_non_var():
+    @R.function
+    def f() -> R.Tensor:
+        return R.const([1, 2])
+
+    body = normalize(f).body
+    body_str = strip_whitespace(dump_ast(body))
+    # the constant has a shape of (2,)
+    struct_info = strip_whitespace(
+        """
+        struct_info=TensorStructInfo(
+            dtype=int32,
+            shape=ShapeExpr(
+                values=[PrimExpr(value=`T.int64(2)`)],
+                struct_info=ShapeStructInfo(
+                    ndim=1,
+                    values=[PrimExpr(value=`T.int64(2)`)]
+                ),
+                checked_type_=ShapeType(ndim=1)
+            )
+        )
+        """
+    )
+    assert struct_info in body_str
+
+
+def test_print_type_annotation_non_var():
+    @R.function
+    def f() -> R.Shape:
+        return R.shape_of(R.const(1))
+
+    body = normalize(f).body
+    assert isinstance(body, rx.SeqExpr)
+    call = body.blocks[-1].bindings[-1].value
+    assert isinstance(call, rx.Call)
+    arg = call.args[0]
+    arg_str = strip_whitespace(dump_ast(arg))
+    # the constant should have a tensor type
+    assert "checked_type_=DynTensorType(ndim=0" in arg_str
+
+    call_str = strip_whitespace(dump_ast(call))
+    # we expect the shape_of call to have a checked_type_ of ShapeType
+    type_str = "checked_type_=ShapeType(ndim=-1)"
+    assert type_str in call_str
+
+
+def test_if():
+    @R.function
+    def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"):
+        if cond:
+            x = R.const(1)
+        else:
+            x = R.const(2)
+        return x
+
+    body = normalize(f).body
+    assert isinstance(body, rx.SeqExpr)
+    body_str = strip_whitespace(dump_ast(body))
+    # we expect both branches to be seq exprs
+    assert "If" in body_str
+    assert "true_branch=SeqExpr(" in body_str
+    assert "false_branch=SeqExpr(" in body_str
+
+
+def test_tuple_get_item():
+    @R.function
+    def f(x: R.Tuple(R.Tensor((), dtype="int32"))) -> R.Tensor((), 
dtype="int32"):
+        return x[0]
+
+    body = normalize(f).body
+    assert isinstance(body, rx.SeqExpr)
+    body_str = strip_whitespace(dump_ast(body))
+
+    assert "TupleGetItem" in body_str
+    assert 'tuple_value=Var(name_hint="x"' in body_str
+    assert "index=0" in body_str
+
+
+def test_prim_value():
+    prim_value = rx.PrimValue(tir.IntImm("int64", 1))
+    prim_str = strip_whitespace(dump_ast(prim_value))
+    assert prim_str == strip_whitespace(
+        """
+        PrimValue(
+            value=PrimExpr(value=`T.int64(1)`),
+            struct_info=PrimStructInfo(dtype=int64),
+            checked_type_=PrimType(dtype=int64)
+        )
+    """
+    )
+
+
+def test_string_imm():
+    string_imm = rx.StringImm("test")
+    str_str = strip_whitespace(dump_ast(string_imm))
+    assert str_str == strip_whitespace(
+        """
+        StringImm(
+            value="test",
+            struct_info=ObjectStructInfo(),
+            checked_type_=ObjectType()
+        )
+    """
+    )
+
+
+def test_datatype_imm():
+    data_type_imm = rx.DataTypeImm("int32")
+    data_type_str = strip_whitespace(dump_ast(data_type_imm))
+    assert data_type_str == strip_whitespace(
+        """
+        DataTypeImm(
+            value=int32,
+            struct_info=ObjectStructInfo(),
+            checked_type_=ObjectType()
+        )
+    """
+    )
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script via tvm.testing.main().
+    tvm.testing.main()


Reply via email to