This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ff1a02c66aa69fd0e8707a671172f3942a7f61e5 Author: Yuchen Jin <yuch...@cs.washington.edu> AuthorDate: Fri Feb 10 23:33:46 2023 -0800 [Unity] Relax VM shape lowering pass (#13956) This PR introduces Relax `FunctionPass` and `DataflowBlockPass` API, and the `VMShapeLower` pass to lower the shape expression in Relax to TIR functions and VM shape heap builtin functions. Co-Authored-by: Ziheng Jiang <zih...@apache.org> Co-Authored-by: Lesheng Jin <34279105+lesheng...@users.noreply.github.com> Co-Authored-by: Altan Haan <alt...@cs.washington.edu> Co-Authored-by: Junru Shao <junrushao1...@gmail.com> Co-Authored-by: Prakalp Srivastava <prak...@octoml.ai> Co-Authored-by: Ruihang Lai <ruiha...@cs.cmu.edu> Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn> Co-Authored-by: Steven S. <Lyubomirsky slyubomir...@octoml.ai> Co-Authored-by: Sunghyun Park <49998730+sun...@users.noreply.github.com> Co-Authored-by: Tianqi Chen <tianqi.tc...@gmail.com> Co-Authored-by: Yong Wu <yongc...@gmail.com> --- include/tvm/relax/backend.h | 44 ++ include/tvm/relax/transform.h | 72 ++ python/tvm/relax/__init__.py | 1 + python/tvm/relax/transform/__init__.py | 20 + python/tvm/relax/transform/_ffi_api.py | 19 + python/tvm/relax/transform/transform.py | 345 ++++++++++ src/relax/backend/vm/vm_shape_lower.cc | 725 +++++++++++++++++++++ src/relax/ir/transform.cc | 413 ++++++++++++ .../relax/test_backend_transform_shape_lower.py | 429 ++++++++++++ 9 files changed, 2068 insertions(+) diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h new file mode 100644 index 0000000000..4ebeacac0f --- /dev/null +++ b/include/tvm/relax/backend.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend.h + * \brief Relax backend specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_H_ +#define TVM_RELAX_BACKEND_H_ + +#include <tvm/relax/transform.h> + +namespace tvm { +namespace relax { +namespace transform { + +/*! + * \brief Lower the shape expression in relax to VM shape heap and TIR functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMShapeLower(); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h new file mode 100644 index 0000000000..fa288a7f06 --- /dev/null +++ b/include/tvm/relax/transform.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform.h + * \brief Relax specific transformation passes. + */ +#ifndef TVM_RELAX_TRANSFORM_H_ +#define TVM_RELAX_TRANSFORM_H_ + +#include <tvm/ir/transform.h> +#include <tvm/relax/expr.h> + +namespace tvm { +namespace relax { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; + +/*! + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func, + int opt_level, String name, tvm::Array<String> required); + +/*! + * \brief Create a dataflowblock pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the dataflowblock pass. + * \param name The name of the dataflowblock pass. + * \param required The list of the passes that the dataflowblock pass is dependent on. + * + * \return The created dataflowblock pass. + */ +TVM_DLL Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func, + int opt_level, String name, tvm::Array<String> required); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index ce175354d0..a6306b788e 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -20,6 +20,7 @@ from . import exec_builder from . import expr from . import ty from . import analysis +from . import transform from . import vm from . import block_builder from . import op diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py new file mode 100644 index 0000000000..eb4d5f710c --- /dev/null +++ b/python/tvm/relax/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax transformations.""" + +from .transform import * diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py new file mode 100644 index 0000000000..667aa62c2c --- /dev/null +++ b/python/tvm/relax/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.relax.transform""" +import tvm._ffi + +tvm._ffi._init_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py new file mode 100644 index 0000000000..f20f06c522 --- /dev/null +++ b/python/tvm/relax/transform/transform.py @@ -0,0 +1,345 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Relax transformation passes.""" +import functools +import inspect +import types +from typing import Callable, Union + +import tvm.ir +from . import _ffi_api + + +@tvm._ffi.register_object("relax.FunctionPass") +class FunctionPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.Function in a module. A function + pass class should be created through `function_pass`. + """ + + +@tvm._ffi.register_object("relax.DataflowBlockPass") +class DataflowBlockPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.DataflowBlock in a module.""" + + +def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: + """Lower symbolic shape computation and the argument and match-cast StructInfo + matching to VM shape heap builtins and TIR shape functions. + + Parameters + ---------- + emit_err_ctx : bool + Whether to emit the error context string; can be turned off for testing purposes.
+ + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass.""" + + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeFunctionPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to the wrapped instance's attribute if the wrapper has none + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def function_pass( + pass_func=None, + opt_level=None, + name=None, + required=None, +) -> Union[Callable, FunctionPass]: + """Decorate a function pass. + + This function returns a decorator when pass_func + is not provided. Otherwise, it returns the function pass created from the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. + + opt_level : int + The optimization level of this function pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the function pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise the decorated result is returned. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + @R.function + def f1(x: Tensor[(m, n), "float32"]): + return x + + @tvm.script.ir_module + class InputMod: + @R.function + def f2(x: Tensor[(m, n), "float32"]): + gv0 = relax.add(x, x) + return gv0 + # fpass is now a special pass that replaces every + # function with f1 + fpass = TestReplaceFunc(f1) + # now every function in InputMod is replaced by f1 + updated_mod = fpass(InputMod) + + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here.
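+ # (whatever is returned here must itself be a relax.Function)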
+ return func + + function_pass = transform + assert isinstance(function_pass, relax.transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the following: + updated_mod = function_pass(m) + # Now transform should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the function pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required is expected to be a list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Function pass") + return _ffi_api.MakeFunctionPass(pass_arg, info) # type: ignore + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass + + +def _wrap_class_dataflowblock_pass(pass_cls, pass_info): + """Wrap a python class as dataflowblock pass""" + + class PyDataflowBlockPass(DataflowBlockPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_dataflowblock(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to the wrapped instance's attribute if the wrapper has none + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__) + PyDataflowBlockPass.__name__ = pass_cls.__name__ + PyDataflowBlockPass.__doc__ = pass_cls.__doc__ + PyDataflowBlockPass.__module__ = pass_cls.__module__ + return PyDataflowBlockPass + + +def dataflowblock_pass( + pass_func=None, opt_level=None, name=None, required=None +) -> Union[Callable, DataflowBlockPass]: + """Decorate a dataflowblock pass. + + This function returns a decorator when pass_func + is not provided. Otherwise, it returns the dataflowblock pass created from the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]] + The transformation function or class. + + opt_level : int + The optimization level of this dataflowblock pass. + + name : Optional[str] + The name of the dataflowblock pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the dataflowblock pass is dependent on. + + Returns + ------- + create_dataflowblock_pass : Union[Callable, DataflowBlockPass] + + A decorator will be returned if pass_func is not provided, + otherwise the decorated result is returned. + The returned decorator has two behaviors depending on the input: + A new DataflowBlockPass will be returned when we decorate a pass function. + A new DataflowBlockPass class will be returned when we decorate a class type.
+ + Examples + -------- + The following code block decorates a dataflowblock pass class. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=1) + class TestReplaceBinding: + # Simple test class to replace the first VarBinding with another. + + def __init__(self): + # create a new VarBinding + m, n = tir.Var("m", "int64"), tir.Var("n", "int64") + lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32")) + val = relax.const(np.random.rand(24, 56)) + self.new_binding = relax.VarBinding(lv0, val) + + def transform_dataflowblock(self, block, mod, ctx): + # just for demo purposes + # Replace the first binding in the DataflowBlock + new_bindings = [self.new_binding, block.bindings[1]] + new_block = relax.expr.DataflowBlock(new_bindings, block.span) + return new_block + + @tvm.script.ir_module + class InputMod: + @R.function + def f1(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + lv0 = relax.multiply(x, x) + gv0 = relax.add(x, x) + relax.output(gv0) + return gv0 + # block_pass is now a special pass that replaces every + # first binding with the constant value binding + block_pass = TestReplaceBinding() + # now every first binding in DataflowBlock of InputMod + # is replaced by new_binding + updated_mod = block_pass(InputMod) + + + The following code creates a dataflowblock pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=2) + def transform(block, mod, ctx): + # my transformations here. + return block + + block_pass = transform + assert isinstance(block_pass, relax.transform.DataflowBlockPass) + assert block_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the following: + updated_mod = block_pass(m) + # Now transform should have been applied to every DataflowBlock in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the dataflowblock pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required is expected to be a list/tuple.") + + def create_dataflowblock_pass(pass_arg): + """Internal function that creates a dataflowblock pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_dataflowblock_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for DataflowBlock pass") + return _ffi_api.MakeDataflowBlockPass(pass_arg, info) # type: ignore + + if pass_func: + return create_dataflowblock_pass(pass_func) + return create_dataflowblock_pass diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc new file mode 100644 index 0000000000..090bcf01b5 --- /dev/null +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -0,0 +1,725 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_shape_lower.cc + * \brief Lower the function boundary type checks and symbolic shape computations. + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/backend.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/struct_info.h> +#include <tvm/relax/struct_info_functor.h> +#include <tvm/runtime/relax_vm/builtin.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/function.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt_functor.h> + +namespace tvm { +namespace relax { + +/*! \brief A slot used in PrimExpr lowering. */ +struct PrimExprSlot { + /*! \brief The PrimExpr this slot corresponds to. */ + PrimExpr expr; + /*! \brief The slot index */ + int index; + // The following three members are auxiliary data + // to help shape rewriting. + /*! + * \brief List of slots whose PrimExpr uses this PrimExpr. + * \note user_slots is non-empty only if this PrimExpr is a Var; the list does not include the slot itself. + */ + std::vector<PrimExprSlot*> user_slots; + /*! + * \brief Number of vars used by this PrimExpr that are not yet defined. + * \note This is a helper counter used in analysis to perform computations. + */ + int outstanding_defs = 0; + /*! \brief Whether we have computed the value. */ + bool value_computed = false; +}; + +/*! + * \brief Helper data structure to collect pairs of match shapes + * in a recursive matching process. + */ +struct MatchShapeTodoItem { + Expr input; + Array<PrimExpr> pattern; + String err_ctx; +}; + +/*! \brief Slot map used for shape lowering. */ +using PrimExprSlotMap = + std::unordered_map<PrimExpr, PrimExprSlot*, StructuralHash, tir::ExprDeepEqual>; + +// Collector to collect PrimExprSlotMap +class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { + public: + // collect the PrimExpr slot for a given function + static void Collect(Function func, std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec, + PrimExprSlotMap* slot_map) { + PrimExprSlotCollector collector; + collector.slot_vec_ = slot_vec; + collector.slot_map_ = slot_map; + // collect shape declaration in func params + for (auto param : func->params) { + collector.VisitStructInfo(GetStructInfo(param)); + collector.VisitExpr(param); + } + collector.VisitExpr(func->body); + } + + private: + void VisitPrimExpr(const PrimExpr& expr) final { + if (expr->IsInstance<IntImmNode>()) return; + if (slot_map_->count(expr) == 0) { + auto slot = std::make_unique<PrimExprSlot>(); + slot->expr = expr; + slot->index = static_cast<int>(slot_vec_->size()); + slot_map_->emplace(expr, slot.get()); + slot_vec_->emplace_back(std::move(slot)); + } + } + + void VisitBinding_(const MatchCastNode* op) final { + // Visit the match cast struct info so we can define + // the symbolic variables here.
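+ // (Slots must exist before RunMatch / EmitOutstandingPrimExprCompute can reference them.)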
+ this->VisitStructInfo(op->struct_info); + } + + void VisitExpr_(const FunctionNode* op) final { + // Do not recurse into function node as it is self-contained + } + + void VisitStructInfo_(const FuncStructInfoNode* op) final { + // Do not recurse into function struct info as it is self-contained + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + + void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } + + std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec_; + PrimExprSlotMap* slot_map_; +}; + +/*! + * \brief Main logic to transform the shape lowered functions + * + * Consider the following input: + * + * \code + * + * def f(x: R.Tuple(R.Tensor([m, n+1]), R.Tensor([n, 2]))) -> R.Tensor: + * return x + * + * \endcode + * + * Overall flow of the algorithm: + * - Preprocess: PrimExprSlot collection, we scan the function and allocate a PrimExprSlot + * for each PrimExpr. In the above example, the resulting mapping from slot index + * to expr would be {0: m, 1: n+1, 2: n}. Note that "n+1" also gets a slot. + * PrimExprSlot also comes with auxiliary fields that track whether its value + * can be readily computed. + * + * Steps at each matching point: + * - Step 0: We call CheckMatchCast, + * which will recursively unpack the StructInfo, and generate static information checks. + * Note that this step only generates checks for types and ndim info, but not + * the symbolic shape variables. The symbolic shape-matching results will be returned as + * vector<MatchShapeTodoItem>. This is because symbolic shape matching may not be completed + * in a single round. Importantly, CheckMatchCast also deals with tuple unpacking. + * + * - Step 1: We then call RunMatch to generate the statements for matching symbolic shapes. + * In the above example, the first round will store the values of m and n to their corresponding + * slots. RunMatch may return outstanding items. In the above example x[0].shape[1] == n+1 cannot + * be checked in the first round. RunMatch will populate new vars (in this case m and n); these + * vars are added to a ready queue (ready_vars_). + * + * - Step 2: We call EmitOutstandingPrimExprCompute to check whether the ready vars trigger new + * values to be computed. We eagerly compute all the outstanding values. The trigger is done + * through a ref counter which decreases when each outstanding def is satisfied. + * This step can also generate additional TIR functions to carry out shape computations. + * + * - Step 3: We run RunMatch again on the outstanding match todos. This time all invariants + * should be checked. + * + * The above steps populate each slot (which is backed by an element in shape_heap). + * Each time we find a symbolic shape tuple, we call MakeShape with the corresponding slot indices + * in the shape_heap. + * + * + * Key functions in the flow: + * - PrimExprSlotCollector: preprocessing and collecting the slots + * - CheckMatchCast: recursively unpacks StructInfo, generating checks and match items. + * - RunMatch: generates symbolic shape matches + * - EmitOutstandingPrimExprCompute: tracks the variables to be computed and emits shape computation + * - VisitExpr_(ShapeExprNode*): builds the symbolic shape tuple. + * + * The checks and symbolic shape construction all map to runtime builtin functions. Please check + * out runtime/relax_vm/builtin.cc for their definitions. + * + * Shape computations are lowered to host-side TIR functions that load vars from their slots + * and store computed results back into the slots.
For a given slot map: {0: m, 1: n+1, 2: n} + * It will create the shape_func below, which loads data from H[2] (n's slot), runs the compute, + * and stores the result back to H[1] (n+1's slot). + * + * \code + * + * @T.prim_func + * def shape_func(H: T.Buffer([3], "int64")): + * H[1] = H[2] + 1 + * + * \endcode + * + * The current implementation will batch all shape computations at each match point. + * For example, all the expressions that depend on n, m will be computed in a single + * shape_func at the function boundary. If there are follow-up match_cast points + * that define new variables, we will generate new shape functions + * to compute the expressions that depend on these variables. + */ +class VMShapeLowerMutator + : public ExprMutator, + public StructInfoFunctor<void(const StructInfo&, Expr, bool, const String&, + std::vector<MatchShapeTodoItem>*)> { + public: + static IRModule Lower(IRModule mod, bool emit_err_ctx) { + VMShapeLowerMutator mutator(mod, emit_err_ctx); + + for (auto& kv : mod->functions) { + if (auto* func = kv.second.as<FunctionNode>()) { + Function updated_func = mutator.Rewrite(kv.first, GetRef<Function>(func)); + mutator.builder_->UpdateFunction(kv.first, updated_func); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit VMShapeLowerMutator(IRModule mod, bool emit_err_ctx) + : ExprMutator(mod), emit_err_ctx_(emit_err_ctx) {} + + using ExprMutator::VisitExpr_; + + // Unit rewrite function per function. + Function Rewrite(GlobalVar gvar, Function func) { + // prepare mapping and heap var + PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_); + heap_size_ = IntImm(ShapeDType(), static_cast<int64_t>(slot_vec_.size())); + VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_); + shape_heap_ = shape_heap_binding->var; + + // prepare slot information + this->PopulateSlotInfo(); + + Array<BindingBlock> blocks; + + builder_->BeginScope(func->params); + + { + // Check the parameter section. + builder_->BeginBindingBlock(); + this->builder_->EmitNormalized(shape_heap_binding); + std::vector<MatchShapeTodoItem> match_todos; + for (size_t i = 0; i < func->params.size(); ++i) { + StructInfo sinfo = GetStructInfo(func->params[i]); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i + << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") "; + this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), &match_todos); + } + // insert heap generation logic. + match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + BindingBlock pre_block = builder_->EndBlock(); + blocks.push_back(pre_block); + } + + // new body. + auto body_seq = Downcast<SeqExpr>(this->VisitWithNewScope(func->body, func->params)); + blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end()); + + { + // Insert the return value check + builder_->BeginBindingBlock(); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint + << ", loc=return, annotation=" << func->ret_struct_info << ") "; + std::vector<MatchShapeTodoItem> match_todos; + // NOTE: the return value's shape computation must already be defined. + this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, err_ctx.str(), + &match_todos);
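+ // run the final RunMatch in require_value_computed mode: every symbolic + // shape in the return annotation must already have its heap slot filled.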
+ this->RunMatch(match_todos, true); + BindingBlock post_block = builder_->EndBlock(); + blocks.push_back(post_block); + } + + auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body)); + // create a new function + return Function(func->params, new_body, func->ret_struct_info, func->attrs); + } + + //------------------------------------------------------- + // PrimExpr slot handling + //------------------------------------------------------- + static DataType ShapeDType() { return DataType::Int(64); } + + /*! \brief Populate additional information in the slots. */ + void PopulateSlotInfo() { + for (auto& kv : slot_map_) { + auto* slot = kv.second; + if (!slot->expr.as<tir::VarNode>()) { + Array<tir::Var> dep_vars = tir::UndefinedVars(slot->expr); + for (auto var : dep_vars) { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()) + << "Var " << var << " is not defined in the function but is referenced by " + << slot->expr; + auto* var_slot = it->second; + // populate the use slot. + var_slot->user_slots.push_back(slot); + } + // set outstanding defs. + slot->outstanding_defs += static_cast<int>(dep_vars.size()); + } + } + } + //------------------------------------------------------- + // Helper functions + //------------------------------------------------------- + StringImm GetErrContext(String err_ctx) const { + return emit_err_ctx_ ? StringImm(err_ctx) : StringImm(""); + } + + VarBinding AllocShapeHeapBinding(IntImm heap_size) { + if (heap_size->value > 0) { + TensorStructInfo heap_sinfo(ShapeDType(), 1); + Var var("shape_heap", heap_sinfo); + // set up the builtin func. + Call call(call_builtin_with_ctx_op_, + {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_sinfo}); + UpdateStructInfo(call, heap_sinfo); + return VarBinding(var, call); + } else { + Var var("shape_heap", ObjectStructInfo()); + Call call(null_value_op_, {}); + UpdateStructInfo(call, ObjectStructInfo()); + return VarBinding(var, call); + } + } + + //------------------------------------------------------- + // Expr mutation overloading. + //------------------------------------------------------- + Expr VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "VMShapeLower does not work for local functions, make sure " + << "to run it after LambdaLift"; + return GetRef<Expr>(op); + } + + Expr VisitExpr_(const ShapeExprNode* op) final { + using runtime::relax_vm::MakeShapeCode; + // Constant shape can be preserved.
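+ // e.g. ShapeExpr([2, 4]) contains only IntImm values and needs no shape heap access.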
+ bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) { + return e->IsInstance<IntImmNode>(); + }); + if (is_const_shape) { + return GetRef<Expr>(op); + } + + Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))}; + for (PrimExpr expr : op->values) { + if (auto* int_expr = expr.as<IntImmNode>()) { + args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm))); + args.push_back(PrimValue::Int64(int_expr->value)); + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed"; + args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape))); + args.push_back(PrimValue::Int64(slot->index)); + } + } + + // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n-1], r[n-1]) + Call call(builtin_make_shape_, args, Attrs(), + {ShapeStructInfo(static_cast<int>(op->values.size()))}); + return call; + } + + void VisitBinding_(const MatchCastNode* binding) final { + Expr value = ExprMutator::VisitExpr(binding->value); + std::vector<MatchShapeTodoItem> match_todos; + std::ostringstream err_ctx; + err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") "; + // always_check=false + this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), &match_todos); + + match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + // These checks are emitted as extra code; in codegen, + // match-cast is simply ignored and treated as a normal binding. + builder_->EmitNormalized(GetRef<MatchCast>(binding)); + } + + // Do not override shapes in struct info fields. + // We only override shapes that are already part of the normal function values; + // if future passes lift those struct info fields out into normal values, + // codegen may not be able to handle the symbolic values. + // Place this pass as the last pass before codegen. + StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; } + + //------------------------------------------------------- + // Shape computations. + //------------------------------------------------------- + /*! + * \brief Execute the match todo items. + * + * This function can populate vars in the match items when seeing them for the first time. + * These new vars will be added to this->ready_vars_. + * + * If an item contains PrimExprs that are yet to be computed (but may be computable through + * vars defined in this round), it will be returned to the caller. + * + * The caller should call EmitOutstandingPrimExprCompute, then call RunMatch again. + * + * \param match_todos The list of match items to be executed. + * \param require_value_computed Whether we require all exprs to be computed. + * \return List of outstanding items whose values are yet to be computed.
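+ * + * \note For the running example above, the first RunMatch stores m and n into their + * slots (kStoreToHeap) and returns the item containing n+1 as outstanding; once + * EmitOutstandingPrimExprCompute fills n+1's slot, the second call emits + * kAssertEqualToLoad for it.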
+ */ + std::vector<MatchShapeTodoItem> RunMatch(const std::vector<MatchShapeTodoItem>& match_todos, + bool require_value_computed) { + std::vector<MatchShapeTodoItem> outstanding_todos; + + using runtime::relax_vm::MatchShapeCode; + for (const MatchShapeTodoItem& item : match_todos) { + int64_t shape_len = static_cast<int64_t>(item.pattern.size()); + bool all_nop = true; + int num_outstanding_exprs = 0; + + Array<Expr> args = {item.input, shape_heap_, PrimValue::Int64(shape_len)}; + + for (PrimExpr expr : item.pattern) { + MatchShapeCode code = MatchShapeCode::kNoOp; + int64_t rvalue = 0; + if (auto* int_expr = expr.as<IntImmNode>()) { + code = MatchShapeCode::kAssertEqualToImm; + rvalue = int_expr->value; + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + if (slot->value_computed) { + code = MatchShapeCode::kAssertEqualToLoad; + rvalue = slot->index; + } else { + // the value is not yet computed + ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; + if (expr.as<tir::VarNode>()) { + // if it is a var, we will populate it in this round. + // otherwise, we skip and mark it as outstanding + code = MatchShapeCode::kStoreToHeap; + rvalue = slot->index; + slot->value_computed = true; + ready_vars_.push_back(slot); + } else { + code = MatchShapeCode::kNoOp; + rvalue = 0; + ++num_outstanding_exprs; + } + } + } + all_nop = all_nop && code == MatchShapeCode::kNoOp; + args.push_back(PrimValue::Int64(static_cast<int>(code))); + args.push_back(PrimValue::Int64(rvalue)); + } + if (num_outstanding_exprs != 0) { + outstanding_todos.push_back(item); + } + args.push_back(GetErrContext(item.err_ctx)); + if (!all_nop) { + Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + } + return std::move(outstanding_todos); + } + + /*! + * \brief Compute the list of PrimExprs that can now be computed + * given the ready vars. + */ + std::vector<PrimExprSlot*> GetReadyPrimExprSlots() { + std::vector<PrimExprSlot*> to_compute; + for (PrimExprSlot* slot : ready_vars_) { + for (PrimExprSlot* user : slot->user_slots) { + ICHECK_GT(user->outstanding_defs, 0); + user->outstanding_defs -= 1; + if (user->outstanding_defs == 0) { + to_compute.push_back(user); + } + } + } + ready_vars_.clear(); + return to_compute; + } + + /*! + * \brief Check the dependent expressions of ready_vars_. + * + * If there are outstanding PrimExprs that can now be computed, + * we generate a PrimFunc that computes the extra shape values. + * + * We will then clear the ready vars. + * + * \return Number of PrimExprs computed. + */ + size_t EmitOutstandingPrimExprCompute() { + std::vector<PrimExprSlot*> to_compute = GetReadyPrimExprSlots(); + if (to_compute.size() == 0) return 0; + ICHECK_GT(heap_size_->value, 0); + // construct a PrimFunc that computes the shape values.
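+ // e.g. with slot map {0: m, 1: n+1, 2: n} and n newly ready, the emitted + // body is H[1] = H[2] + 1, as in the class-level example above.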
+ tir::Var heap("heap", DataType::Handle()); + Array<PrimExpr> buffer_shape{heap_size_}; + tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); + Map<tir::Var, tir::Buffer> buffer_map; + buffer_map.Set(heap, buffer); + + auto var_map = [&](const tir::Var& var) -> Optional<PrimExpr> { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()); + return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + }; + + Array<tir::Stmt> seq; + for (PrimExprSlot* slot : to_compute) { + ICHECK(!slot->value_computed); + slot->value_computed = true; + PrimExpr value = tir::Substitute(slot->expr, var_map); + seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + } + + tir::Stmt body = tir::SeqStmt::Flatten(seq); + Array<tir::Var> params{heap}; + Type ret_type = VoidType(); + + // TODO(relax-team): Consider attach the target attribute to + // the shape_func to indicate that this is a host function + // This could require us to attach target to the relax function here. + tir::PrimFunc shape_func(params, body, ret_type, buffer_map); + GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + return to_compute.size(); + } + //------------------------------------------------------- + // StructInfo value match logic + // + // CheckMatchCast is the only function needed by + // other code sections + //------------------------------------------------------- + /*! + * \brief Insert runtime check of the match cast condition(value, struct_info). + * + * \param struct_info The struct info to be matched. + * \param value The input value. + * \param always_check Whether we insert runtime check even if we can prove + * that value's struct info already satisfies the condition. + * This option is necessary for argument checking per our calling convention. + * + * \param err_ctx Extra error context to bring more informative error reporting. + * \param match_todos List of match shape todo items collected when recursively + * visit the match cast. + */ + void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) { + return this->VisitStructInfo(struct_info, value, always_check, err_ctx, match_todos); + } + + void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + // short-cut, if the struct info already satisfies the + // constraint during match cast, we can skip matching + if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return; + return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, err_ctx, + match_todos); + } + + void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + } + + void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + // TODO(relax-team) add PrimValue checks later. 
+ LOG(FATAL) << "MatchCast of PrimValue is not yet supported"; + } + + void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { + // check_shape_info(value, ndim, err_ctx) + Call call(builtin_check_shape_info_, + {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), + {void_sinfo_}); + builder_->Emit(call, "_"); + } + if (op->values.defined()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = op->values.value(); + item.err_ctx = err_ctx; + match_todos->push_back(item); + } + } + + void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) { + // check_tensor_info(value, ndim, dtype, err_ctx) + Call call(builtin_check_tensor_info_, + {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + if (auto* shape_expr = op->shape.as<ShapeExprNode>()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = shape_expr->values; + item.err_ctx = err_ctx; + match_todos->push_back(item); + } else if (op->shape.as<VarNode>()) { + // NOTE: This part of the logic is left empty for future support as it is less common. + // Future implementors: we can emit a binding and assert here. + LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears multiple times"; + } else { + ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var"; + } + } + + // Internal helper function to make TupleGetItem. + // This function will try to simplify constant tuples; + // the return value always has struct info. + Expr MakeTupleGetItem(Expr value, int64_t index) { + if (auto* tuple_expr = value.as<TupleNode>()) { + return tuple_expr->fields[index]; + } else if (auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(value)) { + // value has tuple struct info, so it is OK to run tuple get item. + auto ret = TupleGetItem(value, index); + UpdateStructInfo(ret, tuple_sinfo->fields[index]); + return ret; + } else { + // call runtime tuple get item, and return an object.
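+ // callers (e.g. the TupleStructInfo visitor below) continue matching on the returned object.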
+ Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_}); + UpdateStructInfo(call, ObjectStructInfo()); + return call; + } + } + + void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value); + if (value_tinfo) { + CHECK_EQ(value_tinfo->fields.size(), op->fields.size()) + << "TypeError: " << err_ctx << " during match-cast we find a tuple size mismatch"; + } + if (always_check || !value_tinfo) { + // check_tuple_info(value, tuple_size) + Call call(builtin_check_tuple_info_, + {value, PrimValue::Int64(static_cast<int64_t>(op->fields.size())), + GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + // recursively visit each sub-field and run matching + for (size_t i = 0; i < op->fields.size(); ++i) { + this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, err_ctx, + match_todos); + } + } + + void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final { + // we only check that the function is callable. + if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return; + // check_func_info(value, err_ctx) + Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + //------------------------------------------------------- + // Private member fields. + //------------------------------------------------------- + /*! \brief Whether to emit error context; can be turned off for testing purposes. */ + bool emit_err_ctx_{true}; + /*! \brief heap ptr to store the PrimExpr slots. */ + Var shape_heap_; + /*! \brief heap size. */ + IntImm heap_size_; + /*! \brief index => slot. */ + std::vector<std::unique_ptr<PrimExprSlot>> slot_vec_; + /*! \brief Expr => slot. */ + PrimExprSlotMap slot_map_; + /*! + * \brief List of vars that are being defined but + * have not yet gone through the outstanding shape compute check.
+ */ + std::vector<PrimExprSlot*> ready_vars_; + // call builtin cop + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); + // common struct info + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({})); + // check function + const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; + const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; + const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"}; + const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"}; + const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"}; + const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"}; + const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"}; + const ExternFunc builtin_tuple_getitem_{"vm.builtin.tuple_getitem"}; +}; + +namespace transform { + +Pass VMShapeLower(bool emit_err_ctx) { + runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = + [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator::Lower(mod, emit_err_ctx); }; + return CreateModulePass(pass_func, 0, "VMShapeLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { + return VMShapeLower(emit_err_ctx); +}); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc new file mode 100644 index 0000000000..1b077d8b88 --- /dev/null +++ b/src/relax/ir/transform.cc @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/ir/transform.cc + * \brief Relax specific transformation passes. + */ +#include <dmlc/thread_local.h> +#include <tvm/node/repr_printer.h> +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/struct_info_functor.h> +#include <tvm/relax/transform.h> +#include <tvm/relay/function.h> +#include <tvm/runtime/registry.h> +namespace tvm { +namespace relax { +namespace transform { + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm); + +// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream +class FunctionPass; + +/*! + * \brief Function-level passes are used to implement various global + * optimizations for a given Relax IRModule. It fetches one function at a time + * from the function list in the IRModule for optimization. + * + * Note that the scope of passes at this level is a Relax function. Therefore, + * we cannot add or delete a function through these passes as they are not aware + * of the global information. 
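+ * Passes that need to add or delete functions should be implemented as module passes instead.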
+ */ +class FunctionPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax function as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each function in the IRModule. + */ + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func; + + FunctionPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The IRModule that an optimization pass is applied on. + * \param pass_ctx The context that an optimization pass executes on. + * + * \return Return the updated IRModule. + */ + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.FunctionPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + + private: + /* + * \brief Check if a function should be skipped for optimization. + * + * \param func The target function to be checked. + * + * \return Return true if the function will be skipped, otherwise false. + */ + bool SkipFunction(const Function& func) const; +}; + +class FunctionPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL FunctionPass( + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; + +FunctionPass::FunctionPass( + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, + PassInfo pass_info) { + auto n = make_object<FunctionPassNode>(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule optimizations at the Function level. +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << mod; + + IRModule updated_mod = mod->ShallowCopy(); + + std::vector<std::pair<GlobalVar, Function> > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as<FunctionNode>()) { + Function func = GetRef<Function>(n); + auto updated_func = SkipFunction(func) ? 
func : pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << updated_mod; + + return updated_mod; +} + +bool FunctionPassNode::SkipFunction(const Function& func) const { + // TODO(@yuchen): will need to revisit in the future + return (func->GetAttr<String>(relay::attr::kCompiler).defined()) || + func->GetAttr<Integer>(relay::attr::kSkipOptimization, 0) != 0; +} + +Pass CreateFunctionPass( + const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func, + int opt_level, String name, tvm::Array<String> required) { + PassInfo pass_info = PassInfo(opt_level, name, required); + return FunctionPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(FunctionPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") + .set_body_typed( + [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast<const FunctionPassNode*>(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); + +class DataflowBlockPass; + +/*! + * \brief DataflowBlock-level passes are used to implement various dataflow block + * optimizations for a given Relax IRModule. It fetches one dataflow block at a time + * from the functions in an IRModule, and yields a rewritten DataflowBlock. + * + * Note that the scope of passes at this level is a Relax DataflowBlock. Therefore, + * we cannot modify the global scope Vars and symbolic shape Vars defined inside the dataflow block. + */ +class DataflowBlockPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax DataflowBlock as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each DataflowBlock in the IRModule. + */ + runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func; + + DataflowBlockPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.DataflowBlockPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode); +}; + +/*! \brief Helper to apply the passed function to dataflow blocks.*/ +class DataflowBlockMutator : public ExprMutator { + public: + DataflowBlockMutator( + runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func, + IRModule mod, PassContext pass_ctx) + : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {} + + /*! 
+   * \brief Rewrite the DataflowBlockNode with pass_func_.
+   *
+   * This function checks that the pass does not rewrite any of the global scope Vars
+   * and symbolic shape Vars defined inside the dataflow block.
+   */
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final {
+    // collect global scope Vars and symbolic Vars inside the DataflowBlock
+    Map<String, Var> global_scope_vars;
+    Map<String, tir::Var> symbolic_vars;
+    for (const Binding& binding : n->bindings) {
+      Var var = binding->var;
+      if (const auto* match_cast = binding.as<MatchCastNode>()) {
+        auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info);
+        for (const tir::VarNode* var : collected_vars) {
+          symbolic_vars.Set(var->name_hint, GetRef<tir::Var>(var));
+        }
+      }
+      if (!var.as<DataflowVarNode>()) {
+        global_scope_vars.Set(var->name_hint(), var);
+      }
+    }
+
+    // apply pass_func_ to the DataflowBlock
+    DataflowBlock block = GetRef<DataflowBlock>(n);
+    DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_);
+
+    // raise an error if any recorded global scope Var or symbolic Var was rewritten
+    for (const Binding& binding : updated_block->bindings) {
+      Var var = binding->var;
+      if (const auto* match_cast = binding.as<MatchCastNode>()) {
+        auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info);
+        for (const tir::VarNode* var : collected_vars) {
+          if (symbolic_vars.count(var->name_hint) > 0) {
+            tir::Var old_var = symbolic_vars[var->name_hint];
+            ICHECK(var == old_var.get())
+                << "Error: DataflowBlock Pass should not rewrite any Symbolic Var.";
+            symbolic_vars.erase(var->name_hint);
+          }
+        }
+      }
+      if (!var.as<DataflowVarNode>() && global_scope_vars.count(var->name_hint()) > 0) {
+        ICHECK(var.same_as(global_scope_vars[var->name_hint()]))
+            << "Error: DataflowBlock Pass should not rewrite any GlobalScope Var.";
+        global_scope_vars.erase(var->name_hint());
+      }
+    }
+    ICHECK(global_scope_vars.empty() && symbolic_vars.empty())
+        << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var.";
+
+    return std::move(updated_block);
+  }
+
+ private:
+  class SymbolicVarCollector : public StructInfoVisitor {
+   public:
+    static std::unordered_set<const tir::VarNode*> Collect(const StructInfo& info) {
+      SymbolicVarCollector collector;
+      collector.VisitStructInfo(info);
+      return std::move(collector.symbolic_vars_);
+    }
+
+   private:
+    void VisitStructInfoExprField(const PrimExpr& expr) final {
+      if (const tir::VarNode* sym_var = expr.as<tir::VarNode>()) {
+        symbolic_vars_.insert(sym_var);
+      }
+    }
+
+    std::unordered_set<const tir::VarNode*> symbolic_vars_;
+  };
+
+  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func_;
+  IRModule mod_;
+  PassContext pass_ctx_;
+};
+
+class DataflowBlockPass : public Pass {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param pass_func The packed function which implements a pass.
+   * \param pass_info The pass info.
+   */
+  TVM_DLL DataflowBlockPass(
+      runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+      PassInfo pass_info);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode);
+};
+
+DataflowBlockPass::DataflowBlockPass(
+    runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  auto n = make_object<DataflowBlockPassNode>();
+  n->pass_func = std::move(pass_func);
+  n->pass_info = std::move(pass_info);
+  data_ = std::move(n);
+}
+
+// Perform IRModule -> IRModule transformations at the DataflowBlock level.
+IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+  if (pass_ctx->diag_ctx) {
+    DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+    pass_ctx->diag_ctx = previous;
+    previous = tmp;
+  } else {
+    pass_ctx->diag_ctx = previous;
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a bug.";
+
+  const PassInfo& pass_info = Info();
+
+  ICHECK(mod.defined());
+
+  VLOG_CONTEXT << pass_info->name;
+  VLOG(0) << "Executing DataflowBlock pass with opt level: " << pass_info->opt_level;
+  VLOG(1) << "Input module:" << std::endl << mod;
+
+  IRModule updated_mod = mod->ShallowCopy();
+
+  DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod, pass_ctx);
+  std::vector<std::pair<GlobalVar, Function> > updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up relax::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      Function updated_func = Downcast<Function>(dataflow_block_mutator.VisitExpr(func));
+      updates.push_back({it.first, updated_func});
+    }
+  }
+
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a bug.";
+
+  pass_ctx->diag_ctx.value().Render();
+  pass_ctx->diag_ctx = previous;
+
+  VLOG(1) << "Output module:" << std::endl << updated_mod;
+
+  return updated_mod;
+}
+
+Pass CreateDataflowBlockPass(
+    const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<String> required) {
+  PassInfo pass_info = PassInfo(opt_level, name, required);
+  return DataflowBlockPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass")
+    .set_body_typed(
+        [](runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func,
+           PassInfo pass_info) { return DataflowBlockPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<DataflowBlockPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const DataflowBlockPassNode*>(ref.get());
+      const PassInfo info = node->Info();
+      p->stream << "Run DataflowBlock pass: " << info->name << " at the optimization level "
+                << info->opt_level;
+    });
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
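[Editor's note] For readers wiring these registrations up from Python: the `MakeFunctionPass` and `MakeDataflowBlockPass` globals registered above are the low-level entry points that the new `python/tvm/relax/transform/transform.py` wraps. The following is a minimal, illustrative sketch of driving them directly; the function and pass names here are made up for the example, and it assumes the mainline behavior of `tvm.get_global_func` and `tvm.transform.PassInfo` plus the Python-side object registration added elsewhere in this PR.

import tvm
from tvm import relax

def identity_fn(func: relax.Function, mod: tvm.IRModule,
                ctx: tvm.transform.PassContext) -> relax.Function:
    # A FunctionPass pass_func is invoked once per relax.Function in the module.
    return func

def identity_block(block: relax.DataflowBlock, mod: tvm.IRModule,
                   ctx: tvm.transform.PassContext) -> relax.DataflowBlock:
    # A DataflowBlockPass pass_func is invoked once per DataflowBlock. It must not
    # rewrite global scope Vars or symbolic shape Vars defined inside the block,
    # or DataflowBlockMutator raises an error.
    return block

make_function_pass = tvm.get_global_func("relax.transform.MakeFunctionPass")
make_block_pass = tvm.get_global_func("relax.transform.MakeDataflowBlockPass")

# PassInfo(opt_level, name) mirrors the plumbing done by CreateFunctionPass /
# CreateDataflowBlockPass on the C++ side.
fn_pass = make_function_pass(identity_fn, tvm.transform.PassInfo(0, "IdentityFn"))
block_pass = make_block_pass(identity_block, tvm.transform.PassInfo(0, "IdentityBlock"))

# Both behave as ordinary IRModule -> IRModule passes:
# mod = fn_pass(mod)
# mod = block_pass(mod)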
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
new file mode 100644
index 0000000000..0bf0f175dd
--- /dev/null
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -0,0 +1,429 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.ir import assert_structural_equal
+from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_const_shape_arg():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer(T.int64(4), "int64")):
+            """Extra function, checks if the pass preserves it."""
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            shape_heap = R.null_value()
+            _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", sinfo_args=[R.Tuple()])
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer(T.int64(4), "int64")):
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_static_fn_check():
+    """Check static shape and function."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            return y
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            shape_heap = R.null_value()
+            _ = R.call_packed("vm.builtin.check_func_info", f, "", sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", sinfo_args=[R.Tuple()])
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return y
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_simple_symbolic_shape():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            return x
+
+    sindex = {
+        "n": 0,
+        "m": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "", sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_symbolic_compute():
+    MS = MatchShapeCode
+    MK = MakeShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            return (k + 1, m, 2)
+
+    # slot assignment:
+    # 0: n, 1: m, 2: k, 3: k+1
+    sindex = {"n": 0, "m": 1, "k": 2, "k+1": 3}
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def shape_func(H: T.Buffer(T.int64(4), "int64")):
+            # generated compute function
+            H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
+
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(4)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", sinfo_args=[R.Tuple()]
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.NO_OP,
+                0,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = shape_func(shape_heap)
+            # extra assertion on y's shape after shape computation
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k+1"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            # construct shape value for return
+            s = R.call_packed(
+                "vm.builtin.make_shape",
+                shape_heap,
+                3,
+                MK.LOAD_SHAPE,
+                sindex["k+1"],
+                MK.LOAD_SHAPE,
+                sindex["m"],
+                MK.USE_IMM,
+                2,
+                sinfo_args=[R.Shape(ndim=3)],
+            )
+            return s
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_tuple_handling():
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            return x
+
+    # slot assignment:
+    sindex = {"n": 0, "m": 1, "k": 2}
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(3)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            # recursively unpack tuple for static info check
+            _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", sinfo_args=[R.Tuple()])
+            t0 = x[0]
+            _ = R.call_packed(
"vm.builtin.check_tensor_info", + t0, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + t1 = x[1] + _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", sinfo_args=[R.Tuple()]) + t1x0 = t1[0] + _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", sinfo_args=[R.Tuple()]) + t1x1 = t1[1] + _ = R.call_packed( + "vm.builtin.check_tensor_info", + t1x1, + 2, + R.dtype("int32"), + "", + sinfo_args=[R.Tuple()], + ) + # match shape checks. + _ = R.call_packed( + "vm.builtin.match_shape", + t0, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + t1x1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["k"], + "", + sinfo_args=[R.Tuple()], + ) + return x + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_return_match_check(): + """Test when return body is not same as ret_struct_info, runtime match check needed.""" + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + return y + + # slot assignment: + sindex = { + "n": 0, + "m": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", sinfo_args=[R.Tuple()]) + # emit runtime function call since y do not have the right type. + y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, sinfo_args=[R.Object]) + # run check + _ = R.call_packed( + "vm.builtin.check_tensor_info", + y1, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + # shape check + _ = R.call_packed( + "vm.builtin.match_shape", + y1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + + return y + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main()