This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 3928e9b63ca3eb954ce37d7f4338028f7144baa8 Author: Sunghyun Park <49998730+sun...@users.noreply.github.com> AuthorDate: Thu Feb 16 19:35:03 2023 -0800 [Unity][Pass] BindParams pass, FoldConstant pass (#14016) This PR introduces FoldConstant/BindParam passes. Co-authored-by: Yong Wu <yongc...@gmail.com> Co-Authored-by: Hongyi Jin <3231950...@qq.com> Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn> --- include/tvm/ir/function.h | 133 ++++++---- include/tvm/relax/transform.h | 15 ++ python/tvm/relax/transform/transform.py | 62 ++++- src/relax/transform/bind_params.cc | 113 +++++++++ src/relax/transform/fold_constant.cc | 230 +++++++++++++++++ tests/python/relax/test_transform_bind_params.py | 75 ++++++ tests/python/relax/test_transform_fold_constant.py | 280 +++++++++++++++++++++ 7 files changed, 861 insertions(+), 47 deletions(-) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 1493544e73..381ea6b8d6 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -65,6 +65,68 @@ enum class CallingConv : int {    kDeviceKernelLaunch = 2,  };   +/*! + * \brief Supported linkage types. + */ +enum class LinkageType : int { +  /*! +   * \brief Internal linkage. +   */ +  kInternal = 0, +  /*! +   * \brief External linkage. +   - Function with external linkage should have a global symbol attached to it. +   */ +  kExternal = 1 +}; + +/*! + * \brief Generic attribute names that can be attached to any function. + * + * \sa tvm::tir::attr, tvm::relay::attr + */ +namespace attr { +/*! + * \brief Indicates the special calling convention. + * + * Type: Integer + * + * \sa tvm::CallingConv + */ +constexpr const char* kCallingConv = "calling_conv"; + +/*! + * \brief Compilation target of the function. + * + * Type: Target + * + * \sa tvm::Target + */ +constexpr const char* kTarget = "target"; + +/*! + * \brief Global linker symbol of the function in generated code. + * + * This option forces the code generator to name the + * function with the given name. 
+ *  + * For example, we could set a global_symbol of a function + * early to make sure that we can always refer to it by + * the symbol name in the generated DLL. + * + * We should not set the attribute for local functions, + * so that the compiler can freely rename them. + * + * A unique global symbol will be automatically assigned + * to each function in the module before the target code + * generation phase. + * + * Type: String + */ +constexpr const char* kGlobalSymbol = "global_symbol"; + +}  // namespace attr + /*!   * \brief Base node of all functions.   * @@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode {     * \endcode     */    bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } +  /*! +   * \brief Get the type of the linkage. +   * +   * Currently, we only consider external/internal linkage. +   * This can be extended in the future when necessary. +   * +   * \return Linkage type. +   * +   * \code +   * +   *  void Example(const BaseFunc& f) { +   *    if (f->GetLinkageType() == tvm::LinkageType::kExternal) { +   *      // Do not remove a function with external linkage +   *    } +   *  } +   * +   * \endcode +   */ + +  LinkageType GetLinkageType() const { +    if (GetAttr<String>(attr::kGlobalSymbol)) +      return LinkageType::kExternal; +    else +      return LinkageType::kInternal; +  }      static constexpr const char* _type_key = "BaseFunc";    static constexpr const uint32_t _type_child_slots = 2; @@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr {    TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);  };   -/*! - * \brief Generic attribute names that can be attached to any function. - * - * \sa tvm::tir::attr, tvm::relay::attr - */ -namespace attr { -/*! - * \brief Indicates the special calling convention. - * - * Type: Integer - * - * \sa tvm::CallingConv - */ -constexpr const char* kCallingConv = "calling_conv"; - -/*! - * \brief Compilation target of the function. 
- * - * Type: Target - * - * \sa tvm::Target - */ -constexpr const char* kTarget = "target"; - -/*! - * \brief Global linker symbol of the function in generated code. - * - * This option forces the code generator to name the - * function with the given. - * - * For example, we could set a global_symbol of a function - * early to make sure that we can always refer to it by - * the symbol name in the generated DLL. - * - * We should not set the attribute for local functions, - * so that the compiler can freely rename them. - * - * A unique global symbol will be automatically assigned - * to each function in the module before the target code - * generation phase. - * - * Type: String - */ -constexpr const char* kGlobalSymbol = "global_symbol"; - -}  // namespace attr  }  // namespace tvm  #endif  // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index ff98b16d25..dab062588a 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape();   * \return The Pass.   */  TVM_DLL Pass AttachGlobalSymbol(); +/*! + * \brief Bind params of function of the module to constant tensors. + * + * \param func_name The name of the function to bind parameters. + * \param params The parameters to bind. + * + * \return The Pass. + */ +TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);   +/*! + * \brief Fold constant expressions. + * + * \return The Pass. 
+ */ +TVM_DLL Pass FoldConstant();  }  // namespace transform  }  // namespace relax  }  // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 1a525431dd..745a26a4da 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,7 +19,8 @@ import functools  import inspect  import types -from typing import Callable, Union +from typing import Callable, Dict, Union, Optional, List   +import numpy as np  # type: ignore  import tvm.ir  from . import _ffi_api @@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:      return _ffi_api.AttachGlobalSymbol()  # type: ignore   +def BindParams( +    func_name: str, +    params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], +) -> tvm.ir.transform.Pass: +    """Bind params of function of the module to constant tensors. + +    Parameters +    ---------- + +    func_name: str +        The function name to be bound + +    params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]] +        The map from param name to constant tensors. + +    Returns +    ------- +    ret: tvm.ir.transform.Pass +    """ +    tvm_params = {} +    for k, v in params.items(): +        if isinstance(v, np.ndarray): +            v = tvm.nd.array(v) +        assert isinstance( +            v, tvm.runtime.NDArray +        ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" +        tvm_params[k] = v + +    return _ffi_api.BindParams(func_name, tvm_params)  # type: ignore + + +def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass: +    """Remove unused relax/prim functions without external linkage in an IRModule. + +    Parameters +    ---------- +    entry_functions: Optional[List[str]] +        The set of entry functions to start from. + +    Returns +    ------- +    ret : tvm.transform.Pass +        The registered pass to remove unused functions. 
+    """ +    if entry_functions is None: +        entry_functions = ["main"] +    return _ffi_api.RemoveUnusedFunctions(entry_functions)  # type: ignore + + +def FoldConstant() -> tvm.ir.transform.Pass: +    """Fold constant expressions. + +    Returns +    ------- +    ret: tvm.ir.transform.Pass +    """ +    return _ffi_api.FoldConstant()  # type: ignore + +  def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:      """Annotate Op Pattern Kind for TIR functions diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc new file mode 100644 index 0000000000..1de8d94461 --- /dev/null +++ b/src/relax/transform/bind_params.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/driver/driver_api.h> +#include <tvm/ir/function.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/transform.h> +#include <tvm/relax/type.h> +#include <tvm/tir/op.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! +
 * \brief Bind params to function by using name + * \param func Relax function + * \param params Map from parameter name to the constant to bind; names matching no parameter are skipped, and a name shared by several parameters is a fatal error + * \return The function with the matched parameters bound to Constant values + */ +inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) { +  std::unordered_map<std::string, Var> name_dict; +  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var; +  for (auto arg : func->params) { +    const auto& name = arg->name_hint(); +    if (name_dict.count(name)) { +      repeat_var.insert(name_dict[name]);  // remember ambiguous names; only an error if bound +    } else { +      name_dict[name] = arg; +    } +  } + +  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict; +  for (auto& kv : params) { +    if (name_dict.count(kv.first) == 0) { +      continue; +    } +    auto arg = name_dict.at(kv.first); +    if (repeat_var.count(arg)) { +      LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first; +    } +    bind_dict[arg] = Constant(kv.second); +  } +  Expr bound_expr = Bind(func, bind_dict); +  Function ret = Downcast<Function>(bound_expr); +  ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function." +                        << "\n"; +  return ret; +} + +/*! + * \brief Bind params to a specific function in a module; a function with a global_symbol attribute is matched by that symbol, otherwise by its GlobalVar name_hint + * \param m The module + * \param func_name The name of the specific function + * \param param The param dict + * \return The module after binding params.
 */ +IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) { +  IRModuleNode* new_module = m.CopyOnWrite(); +  Map<GlobalVar, BaseFunc> functions = m->functions;  // presumably a snapshot, since Update() below modifies the module +  for (const auto& func_pr : functions) { +    if (const auto* relax_f = func_pr.second.as<FunctionNode>()) { +      if (relax_f->GetLinkageType() == LinkageType::kExternal) { +        // Use global_symbol if it's external linkage +        Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol); +        if (gsymbol.defined() && gsymbol.value() == func_name) { +          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param); +          new_module->Update(func_pr.first, f_after_bind); +        } +      } else { +        // Use global var's name_hint if it's internal linkage +        if (func_pr.first->name_hint == func_name) { +          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param); +          new_module->Update(func_pr.first, f_after_bind); +        } +      } +    } +  } +  return GetRef<IRModule>(new_module); +} + +namespace transform { + +Pass BindParams(String func_name, Map<String, runtime::NDArray> params) { +  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = +      [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; +  return CreateModulePass(pass_func, 0, "BindParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); + +}  // namespace transform + +}  // namespace relax +}  // namespace tvm diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc new file mode 100644 index 0000000000..aa55ee7f7e --- /dev/null +++ b/src/relax/transform/fold_constant.cc @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/driver/driver_api.h> +#include <tvm/ir/function.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/transform.h> +#include <tvm/relax/type.h> +#include <tvm/tir/function.h> +#include <tvm/tir/op.h> + +namespace tvm { +namespace relax { + +class ConstantFolder : public ExprMutator { + public: + explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {} + + private: + /*! + * \brief Pattern match the shape inside the given struct info to a + * constant shape and get runtime shape tuple from it. + * \param struct_info The given struct info whose shape inside is to be casted. + * \return The runtime shape tuple, or nullopt if it is not a constant shape. + * \note Only TensorStructInfo is supported at this moment. Return NullOpt + * if the input struct info is not TensorStructInfo. + */ + static Optional<runtime::ShapeTuple> MatchConstShape(const StructInfo& struct_info) { + // Only support single output for call_tir at this moment. 
+ const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>(); + if (tensor_sinfo == nullptr) { + return NullOpt; + } + + const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>(); + ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; + + std::vector<int64_t> shape_values; + for (const auto v : shape->values) { + auto* ptr = v.as<IntImmNode>(); + if (!ptr) return NullOpt; + shape_values.push_back(ptr->value); + } + return runtime::ShapeTuple(shape_values.begin(), shape_values.end()); + } + + /*! + * \brief Pattern match op to constant array arguments. + * \return The constant array arguments, or nullopt if match fails. + */ + static Optional<Array<runtime::NDArray>> MatchConstArrayArgs(const Array<Expr>& args) { + Array<runtime::NDArray> res; + for (auto arg : args) { + auto* ptr = arg.as<relax::ConstantNode>(); + if (!ptr) return NullOpt; + res.push_back(ptr->data); + } + return res; + } + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or nullopt if pattern match fails. + */ + Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) { + if (auto* ptr = op.as<GlobalVarNode>()) { + // NOTE: as check works for nullptr(returns null) + Optional<BaseFunc> base_func = ctx_module_->functions.Get(GetRef<GlobalVar>(ptr)); + if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) { + return GetRef<tir::PrimFunc>(pfunc); + } + } + return NullOpt; + } + + /*! + * \brief Get a cached build version of func + * \return The cached func, nullopt if func cannot be built. 
+ */ + Optional<PackedFunc> GetCachedBuild(tir::PrimFunc func) { + // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once + // would be helpful for future cases where PrimFunc recursively call into each other + Target eval_cpu_target{"llvm"}; + + auto it = func_build_cache_.find(func); + if (it != func_build_cache_.end()) { + return it->second; + } + Optional<PackedFunc> build_func = NullOpt; + + try { + // Not all the primfunc can be directly built via llvm, for example, if a function is + // already scheduled to only work on GPU, we will need to skip this in the const folder for + // now + // TODO(Hongyi): further check and narrow the scope of foldable function + runtime::Module rt_module = + build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + build_func = rt_module.GetFunction("tir_function"); + } catch (const tvm::Error& err) { + // build failure may happen in which case we skip + DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); + } + func_build_cache_[func] = build_func; + return build_func; + } + + // Try constant evaluate the function call + // if failed return NullOpt + Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array<runtime::NDArray> arr_args, + runtime::ShapeTuple shape, DataType ret_type) { + // obtain function from the cache. 
+ Optional<PackedFunc> func = GetCachedBuild(tir_func); + if (!func) return NullOpt; + + // here the vector size has an additional + 1 because we need to put ret_tensor at the end + std::vector<TVMValue> values(arr_args.size() + 1); + std::vector<int> type_codes(arr_args.size() + 1); + + DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; + runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev); + + // avoid set rvalue ref which get de-allocated later, store args in a vector + // where temp_args[i] are lvalue ref that is stable + std::vector<runtime::NDArray> temp_args(arr_args.begin(), arr_args.end()); + + size_t arg_offset = 0; + for (; arg_offset < arr_args.size(); ++arg_offset) { + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset, temp_args[arg_offset]); + } + // set return value + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++, ret_tensor); + + TVMRetValue ret; + // invoke + func.value().CallPacked(TVMArgs(values.data(), type_codes.data(), values.size()), &ret); + return Constant(ret_tensor); + } + + Expr VisitCallTIR(Call call) { + // call_tir needs to have at least three arguments + ICHECK_GE(call->args.size(), 2); + Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]); + ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple"; + Optional<Array<runtime::NDArray>> arr_args = + MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields); + ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; + Optional<runtime::ShapeTuple> shape = MatchConstShape(call->sinfo_args[0]); + bool output_not_tuple = call->sinfo_args.size() == 1; + // Pattern 0: call constant function, const argument with const shape. 
+ if (func && arr_args && shape && output_not_tuple) { + DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type()); + // value_or will return value if it is not null, otherwise return or + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype) + .value_or(call); + } + // TODO(hongyi): support const-fold tuple outputs + return std::move(call); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) final { + // post-order mutation + Call post_call = Downcast<Call>(VisitExprPostOrder_(call)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + if (call->op.same_as(call_tir_op)) { + return VisitCallTIR(post_call); + } + return std::move(post_call); + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + Optional<Expr> opt = LookupBinding(GetRef<Var>(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as<relax::ConstantNode>()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const VarNode* op) final { + Optional<Expr> opt = LookupBinding(GetRef<Var>(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as<relax::ConstantNode>()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + // the context module to lookup functions + IRModule ctx_module_; + // cache for function build, via structural equality + std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>, StructuralHash, StructuralEqual> + func_build_cache_; +}; + +namespace transform { + +Pass FoldConstant() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { + ConstantFolder folder(m); + return Downcast<Function>(folder(f)); + }; + return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); + +} // namespace transform + +} 
 // namespace relax +}  // namespace tvm diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py new file mode 100644 index 0000000000..b96fb89e6c --- /dev/null +++ b/tests/python/relax/test_transform_bind_params.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License.
+ +import numpy as np +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + +use_np_array = tvm.testing.parameter(False, True) + + +def test_bind_params(use_np_array): +    @tvm.script.ir_module +    class InputModule: +        @T.prim_func +        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: +            T.func_attr({"global_symbol": "tir_matmul"}) +            A = T.match_buffer(x, (16, 16)) +            B = T.match_buffer(y, (16, 16)) +            C = T.match_buffer(z, (16, 16)) +            for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): +                with T.block("matmul"): +                    vi = T.axis.S(16, i0 * 4 + i1) +                    vj = T.axis.S(16, j) +                    vk = T.axis.R(16, k0 * 4 + k1) +                    with T.init(): +                        C[vi, vj] = T.float32(0) +                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +        @R.function +        def main( +            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") +        ) -> R.Tensor((16, 16), "float32"): +            gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32")) +            return gv0 + +    x_np = np.random.rand(16, 16).astype(np.float32) +    w_np = np.random.rand(16, 16).astype(np.float32) +    x_tvm = tvm.nd.array(x_np) +    w_tvm = tvm.nd.array(w_np) +    params_dict = {"w": w_np if use_np_array else w_tvm} +    mod = relax.transform.BindParams("main", params_dict)(InputModule) +    assert len(mod["main"].params) == 1  # "w" is bound to a constant, so only "x" remains + +    target = tvm.target.Target("llvm") +    ex_after = relax.vm.build(mod, target) +    vm_after = relax.VirtualMachine(ex_after, tvm.cpu()) +    res_after = vm_after["main"](x_tvm) + +    ex_before = relax.vm.build(InputModule, target) +    vm_before = relax.VirtualMachine(ex_before, tvm.cpu()) +    res_before = vm_before["main"](x_tvm, w_tvm) + +    tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy()) + + +if __name__ == "__main__": +    tvm.testing.main() diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py new file mode 100644 index 0000000000..32ee3e7000 --- /dev/null +++ 
b/tests/python/relax/test_transform_fold_constant.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax +import numpy as np + +import tvm.script +from tvm.script import tir as T, relax as R + + +def gen_mod(mod, name, binding): +    """Select the relax function with the given name, rename it to main, and bind constants.
+ +    Parameters +    ---------- +    mod: IRModule +        The input module + +    name: str +        The name of relax function to preserve and rename to main + +    binding: Dict[str, array] +        The const parameter bindings +    """ +    funcs = {} +    binding = {k: tvm.nd.array(v) for k, v in binding.items()} + +    for k, v in mod.functions.items(): +        if isinstance(v, tvm.relax.Function): +            if k.name_hint == name: +                # rename to main +                gv = tvm.ir.GlobalVar("main") +                funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( +                    "global_symbol", "main" +                ) +        else: +            funcs[k] = v +    mod = tvm.IRModule(funcs) +    return relax.transform.BindParams("main", binding)(mod) + + +def test_one_fold_addone(): +    # put the before and expected functions in a single module +    @tvm.script.ir_module +    class Module: +        @T.prim_func +        def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: +            for i, j in T.grid(16, 16): +                with T.block("addone"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vi, vj] + T.float32(1) + +        @R.function +        def before(c0: R.Tensor((16, 16), "float32")): +            lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) +            return lv0 + +        @R.function +        def expected(c1: R.Tensor((16, 16), "float32")): +            lv0 = c1 +            return c1 + +    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) +    c1_np = c0_np + 1 +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c1": c1_np}) + +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +def test_one_fold_transpose(): +    # put the before and expected functions in a single module +    @tvm.script.ir_module +    class Module: +        @T.prim_func +        def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None: +            for i, j in T.grid(3, 2): +                with T.block("transpose"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vj, vi] + +        @R.function +        def before(c0: R.Tensor((2, 3), "float32")): +            lv0 = relax.call_tir(func, (c0,), R.Tensor((3, 2), 
dtype="float32")) +            return lv0 + +        @R.function +        def expected(c1: R.Tensor((3, 2), "float32")): +            lv0 = c1 +            return c1 + +    c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3) +    c1_np = c0_np.T +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c1": c1_np}) + +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +def test_two_hop_addone(): +    @tvm.script.ir_module +    class Module: +        @T.prim_func +        def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> None: +            for i, j in T.grid(2, 2): +                with T.block("addone"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vi, vj] + T.float32(1) + +        @R.function +        def before(c0: R.Tensor((2, 2), "float32")): +            lv0 = relax.call_tir(addone, (c0,), R.Tensor((2, 2), dtype="float32")) +            lv1 = relax.call_tir(addone, (lv0,), R.Tensor((2, 2), dtype="float32")) +            return lv1 + +        @R.function +        def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")): +            lv0 = c1 +            lv1 = c2 +            return c2 + +    c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2) +    c1_np = c0_np + 1 +    c2_np = c1_np + 1 +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np}) + +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +def test_dataflow_fold(): +    @tvm.script.ir_module +    class Module: +        @T.prim_func +        def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: +            for i, j in T.grid(16, 16): +                with T.block("identity"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vi, vj] + +        @R.function +        def before(c0: R.Tensor((16, 16), "float32")): +            with R.dataflow(): +                gv0 = relax.call_tir(identity, (c0,), R.Tensor((16, 16), dtype="float32")) +                R.output(gv0) +            return gv0 + +        @R.function +        def expected(c1: R.Tensor((16, 16), "float32")): +            with R.dataflow(): +                gv0 = c1 +                R.output(gv0) +
            return c1 + +    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) +    c1_np = c0_np +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c1": c1_np}) +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_mixed_case(): +    @tvm.script.ir_module +    class Module: +        # TIR function can handle different cases. +        @T.prim_func +        def addone(a: T.handle, b: T.handle) -> None: +            n = T.var("int32") +            m = T.var("int32") +            A = T.match_buffer(a, (n, m)) +            B = T.match_buffer(b, (n, m)) +            for i, j in T.grid(n, m): +                with T.block("addone"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vi, vj] + T.float32(1) + +        @T.prim_func +        def sub( +            A: T.Buffer[(16, 16), "float32"], +            B: T.Buffer[(16, 16), "float32"], +            C: T.Buffer[(16, 16), "float32"], +        ) -> None: +            for i, j in T.grid(16, 16): +                with T.block("sub"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    C[vi, vj] = A[vi, vj] - B[vi, vj] + +        @R.function +        def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): +            n, m = T.var("int64"), T.var("int64") +            x0 = R.match_cast(x, R.Tensor((n, m), "float32")) +            # this line cannot be folded because n is unknown +            lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) +            # this line can be folded +            lv1 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) +            # this line can be folded because all inputs are const +            lv2 = relax.call_tir(sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) +            # this line can not be folded because x is not a constant +            lv3 = relax.call_tir(sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) +            return lv3 + +        @R.function +        def expected( +            c0: R.Tensor((16, 16), "float32"), +            c1: R.Tensor((16, 16), "float32"), +            c2: R.Tensor((16, 16), "float32"), +            x: 
R.Tensor("float32", ndim=2), +        ) -> R.Tensor: +            n, m = T.var("int64"), T.var("int64") +            x0 = R.match_cast(x, R.Tensor((n, m), "float32")) +            # this line
cannot be folded because n is unknown +            lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) +            # this line can be folded +            lv1 = c1 +            # this line can be folded because all inputs are const +            lv2 = c2 +            # this line can not be folded because x is not a constant +            lv3 = relax.call_tir(sub, (c2, x), R.Tensor((16, 16), dtype="float32")) +            return lv3 + +    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) +    c1_np = c0_np + 1 +    c2_np = c0_np - c1_np + +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2": c2_np}) +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +def test_int32_fold(): +    @tvm.script.ir_module +    class Module: +        @T.prim_func +        def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: +            for i, j in T.grid(16, 16): +                with T.block("addone"): +                    vi, vj = T.axis.remap("SS", [i, j]) +                    B[vi, vj] = A[vi, vj] + T.int32(1) + +        @R.function +        def before(c0: R.Tensor((16, 16), "int32")): +            lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32")) +            return lv0 + +        @R.function +        def expected(c1: R.Tensor((16, 16), "int32")): +            lv0 = c1 +            return c1 + +    c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) +    c1_np = c0_np + 1 +    before = gen_mod(Module, "before", {"c0": c0_np}) +    expected = gen_mod(Module, "expected", {"c1": c1_np}) + +    after = relax.transform.FoldConstant()(before) +    tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": +    tvm.testing.main()