This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3928e9b63ca3eb954ce37d7f4338028f7144baa8
Author: Sunghyun Park <49998730+sun...@users.noreply.github.com>
AuthorDate: Thu Feb 16 19:35:03 2023 -0800

    [Unity][Pass] BindParams pass, FoldConstant pass (#14016)
    
    This PR introduces the FoldConstant and BindParams passes.
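    
    For reference, a minimal usage sketch of the two passes (the module `mod`,
    the function name "main", and the parameter name "w" below are illustrative
    placeholders, not part of this patch):
    
        import numpy as np
        from tvm import relax
    
        # Bind a constant tensor to parameter "w" of the relax function "main".
        w_np = np.random.rand(16, 16).astype("float32")
        mod = relax.transform.BindParams("main", {"w": w_np})(mod)
    
        # Fold call_tir computations whose inputs are now all constants.
        mod = relax.transform.FoldConstant()(mod)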
    
    Co-authored-by: Yong Wu <yongc...@gmail.com>
    Co-Authored-by: Hongyi Jin <3231950...@qq.com>
    Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn>
---
 include/tvm/ir/function.h                          | 133 ++++++----
 include/tvm/relax/transform.h                      |  15 ++
 python/tvm/relax/transform/transform.py            |  62 ++++-
 src/relax/transform/bind_params.cc                 | 113 +++++++++
 src/relax/transform/fold_constant.cc               | 230 +++++++++++++++++
 tests/python/relax/test_transform_bind_params.py   |  75 ++++++
 tests/python/relax/test_transform_fold_constant.py | 280 +++++++++++++++++++++
 7 files changed, 861 insertions(+), 47 deletions(-)

diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 1493544e73..381ea6b8d6 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -65,6 +65,68 @@ enum class CallingConv : int {
   kDeviceKernelLaunch = 2,
 };
 
+/*!
+ * \brief Supported linkage types.
+ */
+enum class LinkageType : int {
+  /*!
+   * \brief Internal linkage.
+   */
+  kInternal = 0,
+  /*!
+   * \brief External linkage.
+   - Function with external linkage should have a global symbol attached to it.
+   */
+  kExternal = 1
+};
+
+/*!
+ * \brief Generic attribute names that can be attached to any function.
+ *
+ * \sa tvm::tir::attr, tvm::relay::attr
+ */
+namespace attr {
+/*!
+ * \brief Indicates the special calling convention.
+ *
+ * Type: Integer
+ *
+ * \sa tvm::CallingConv
+ */
+constexpr const char* kCallingConv = "calling_conv";
+
+/*!
+ * \brief Compilation target of the function.
+ *
+ * Type: Target
+ *
+ * \sa tvm::Target
+ */
+constexpr const char* kTarget = "target";
+
+/*!
+ * \brief Global linker symbol of the function in generated code.
+ *
+ *  This option forces the code generator to name the
+ *  function with the given symbol.
+ *
+ *  For example, we could set a global_symbol of a function
+ *  early to make sure that we can always refer to it by
+ *  the symbol name in the generated DLL.
+ *
+ *  We should not set the attribute for local functions,
+ *  so that the compiler can freely rename them.
+ *
+ *  A unique global symbol will be automatically assigned
+ *  to each function in the module before the target code
+ *  generation phase.
+ *
+ * Type: String
+ */
+constexpr const char* kGlobalSymbol = "global_symbol";
+
+}  // namespace attr
+
 /*!
  * \brief Base node of all functions.
  *
@@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode {
    * \endcode
    */
   bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
+  /*!
+   * \brief Get the type of the linkage.
+   *
+   * Currently, we only consider external/internal linkage.
+   * This can be extended in the future when necessary.
+   *
+   * \return Linkage type.
+   *
+   * \code
+   *
+   *  void Example(const BaseFunc& f) {
+   *    if (f->GetLinkageType() == tvm::LinkageType::kExternal) {
+   *      // Do not remove a function with external linkage
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+
+  LinkageType GetLinkageType() const {
+    if (GetAttr<String>(attr::kGlobalSymbol))
+      return LinkageType::kExternal;
+    else
+      return LinkageType::kInternal;
+  }
 
   static constexpr const char* _type_key = "BaseFunc";
   static constexpr const uint32_t _type_child_slots = 2;
@@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr {
   TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
 };
 
-/*!
- * \brief Generic attribute names that can be attached to any function.
- *
- * \sa tvm::tir::attr, tvm::relay::attr
- */
-namespace attr {
-/*!
- * \brief Indicates the special calling convention.
- *
- * Type: Integer
- *
- * \sa tvm::CallingConv
- */
-constexpr const char* kCallingConv = "calling_conv";
-
-/*!
- * \brief Compilation target of the function.
- *
- * Type: Target
- *
- * \sa tvm::Target
- */
-constexpr const char* kTarget = "target";
-
-/*!
- * \brief Global linker symbol of the function in generated code.
- *
- *  This option forces the code generator to name the
- *  function with the given.
- *
- *  For example, we could set a global_symbol of a function
- *  early to make sure that we can always refer to it by
- *  the symbol name in the generated DLL.
- *
- *  We should not set the attribute for local functions,
- *  so that the compiler can freely rename them.
- *
- *  A unique global symbol will be automatically assigned
- *  to each function in the module before the target code
- *  generation phase.
- *
- * Type: String
- */
-constexpr const char* kGlobalSymbol = "global_symbol";
-
-}  // namespace attr
 }  // namespace tvm
 #endif  // TVM_IR_FUNCTION_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index ff98b16d25..dab062588a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape();
  * \return The Pass.
  */
 TVM_DLL Pass AttachGlobalSymbol();
+/*!
+ * \brief Bind params of a function in the module to constant tensors.
+ *
+ * \param func_name The name of the function whose parameters are to be bound.
+ * \param params The parameters to bind.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
 
+/*!
+ * \brief Fold constant expressions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FoldConstant();
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 1a525431dd..745a26a4da 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,7 +19,8 @@
 import functools
 import inspect
 import types
-from typing import Callable, Union
+from typing import Callable, Dict, Union, Optional, List
+import numpy as np  # type: ignore
 
 import tvm.ir
 from . import _ffi_api
@@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
     return _ffi_api.AttachGlobalSymbol()  # type: ignore
 
 
+def BindParams(
+    func_name: str,
+    params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
+) -> tvm.ir.transform.Pass:
+    """Bind params of function of the module to constant tensors.
+
+    Parameters
+    ----------
+
+    func_name: str
+        The function name to be bound
+
+    params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
+        The map from param name to constant tensors.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    tvm_params = {}
+    for k, v in params.items():
+        if isinstance(v, np.ndarray):
+            v = tvm.nd.array(v)
+        assert isinstance(
+            v, tvm.runtime.NDArray
+        ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but 
got {type(v)}"
+        tvm_params[k] = v
+
+    return _ffi_api.BindParams(func_name, tvm_params)  # type: ignore
+
+
+def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
+    """Remove unused relax/prim functions without external linkage in an IRModule.
+
+    Parameters
+    ----------
+    entry_functions: Optional[List[str]]
+        The set of entry functions to start from.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass to remove unused functions.
+    """
+    if entry_functions is None:
+        entry_functions = ["main"]
+    return _ffi_api.RemoveUnusedFunctions(entry_functions)  # type: ignore
+
+
+def FoldConstant() -> tvm.ir.transform.Pass:
+    """Fold constant expressions.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.FoldConstant()  # type: ignore
+
+
 def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
     """Annotate Op Pattern Kind for TIR functions
 
diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc
new file mode 100644
index 0000000000..1de8d94461
--- /dev/null
+++ b/src/relax/transform/bind_params.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/function.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Bind params to function by using name
+ * \param func Relax function
+ * \param params params dict
+ * \return Function
+ */
+inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
+  std::unordered_map<std::string, Var> name_dict;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(name_dict[name]);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+
+  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = Constant(kv.second);
+  }
+  Expr bound_expr = Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
+                        << "\n";
+  return ret;
+}
+
+/*!
+ * \brief Bind params to a specific function in a module
+ * \param m The module
+ * \param func_name The name of the specific function
+ * \param param The param dict
+ * \return The module after binding params.
+ */
+IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
+  IRModuleNode* new_module = m.CopyOnWrite();
+  Map<GlobalVar, BaseFunc> functions = m->functions;
+  for (const auto& func_pr : functions) {
+    if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+      if (relax_f->GetLinkageType() == LinkageType::kExternal) {
+        // Use global_symbol if it's external linkage
+        Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        if (gsymbol.defined() && gsymbol.value() == func_name) {
+          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
+          new_module->Update(func_pr.first, f_after_bind);
+        }
+      } else {
+        // Use global var's name_hint if it's internal linkage
+        if (func_pr.first->name_hint == func_name) {
+          Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
+          new_module->Update(func_pr.first, f_after_bind);
+        }
+      }
+    }
+  }
+  return GetRef<IRModule>(new_module);
+}
+
+namespace transform {
+
+Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
+  return CreateModulePass(pass_func, 0, "BindParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
new file mode 100644
index 0000000000..aa55ee7f7e
--- /dev/null
+++ b/src/relax/transform/fold_constant.cc
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/function.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+class ConstantFolder : public ExprMutator {
+ public:
+  explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {}
+
+ private:
+  /*!
+   * \brief Pattern match the shape inside the given struct info to a
+   * constant shape and get runtime shape tuple from it.
+   * \param struct_info The given struct info whose shape inside is to be casted.
+   * \return The runtime shape tuple, or nullopt if it is not a constant shape.
+   * \note Only TensorStructInfo is supported at this moment. Return NullOpt
+   * if the input struct info is not TensorStructInfo.
+   */
+  static Optional<runtime::ShapeTuple> MatchConstShape(const StructInfo& struct_info) {
+    // Only support single output for call_tir at this moment.
+    const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
+    if (tensor_sinfo == nullptr) {
+      return NullOpt;
+    }
+
+    const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape";
+
+    std::vector<int64_t> shape_values;
+    for (const auto v : shape->values) {
+      auto* ptr = v.as<IntImmNode>();
+      if (!ptr) return NullOpt;
+      shape_values.push_back(ptr->value);
+    }
+    return runtime::ShapeTuple(shape_values.begin(), shape_values.end());
+  }
+
+  /*!
+   * \brief Pattern match op to constant array arguments.
+   * \return The constant array arguments, or nullopt if match fails.
+   */
+  static Optional<Array<runtime::NDArray>> MatchConstArrayArgs(const Array<Expr>& args) {
+    Array<runtime::NDArray> res;
+    for (auto arg : args) {
+      auto* ptr = arg.as<relax::ConstantNode>();
+      if (!ptr) return NullOpt;
+      res.push_back(ptr->data);
+    }
+    return res;
+  }
+
+  /*!
+   * \brief Pattern match op to a TIR function and look it up.
+   * \return The TIR function, or nullopt if pattern match fails.
+   */
+  Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
+    if (auto* ptr = op.as<GlobalVarNode>()) {
+      // NOTE: as check works for nullptr(returns null)
+      Optional<BaseFunc> base_func = ctx_module_->functions.Get(GetRef<GlobalVar>(ptr));
+      if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+        return GetRef<tir::PrimFunc>(pfunc);
+      }
+    }
+    return NullOpt;
+  }
+
+  /*!
+   * \brief Get a cached build version of func
+   * \return The cached func, nullopt if func cannot be built.
+   */
+  Optional<PackedFunc> GetCachedBuild(tir::PrimFunc func) {
+    // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once
+    // would be helpful for future cases where PrimFunc recursively call into each other
+    Target eval_cpu_target{"llvm"};
+
+    auto it = func_build_cache_.find(func);
+    if (it != func_build_cache_.end()) {
+      return it->second;
+    }
+    Optional<PackedFunc> build_func = NullOpt;
+
+    try {
+      // Not all the primfunc can be directly built via llvm, for example, if a function is
+      // already scheduled to only work on GPU, we will need to skip this in the const folder for
+      // now
+      // TODO(Hongyi): further check and narrow the scope of foldable function
+      runtime::Module rt_module =
+          build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target);
+      build_func = rt_module.GetFunction("tir_function");
+    } catch (const tvm::Error& err) {
+      // build failure may happen in which case we skip
+      DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what();
+    }
+    func_build_cache_[func] = build_func;
+    return build_func;
+  }
+
+  // Try constant evaluate the function call
+  // if failed return NullOpt
+  Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array<runtime::NDArray> arr_args,
+                                      runtime::ShapeTuple shape, DataType ret_type) {
+    // obtain function from the cache.
+    Optional<PackedFunc> func = GetCachedBuild(tir_func);
+    if (!func) return NullOpt;
+
+    // here the vector size has an additional + 1 because we need to put ret_tensor at the end
+    std::vector<TVMValue> values(arr_args.size() + 1);
+    std::vector<int> type_codes(arr_args.size() + 1);
+
+    DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+    runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev);
+
+    // avoid set rvalue ref which get de-allocated later, store args in a vector
+    // where temp_args[i] are lvalue ref that is stable
+    std::vector<runtime::NDArray> temp_args(arr_args.begin(), arr_args.end());
+
+    size_t arg_offset = 0;
+    for (; arg_offset < arr_args.size(); ++arg_offset) {
+      runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset, temp_args[arg_offset]);
+    }
+    // set return value
+    runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++, ret_tensor);
+
+    TVMRetValue ret;
+    // invoke
+    func.value().CallPacked(TVMArgs(values.data(), type_codes.data(), values.size()), &ret);
+    return Constant(ret_tensor);
+  }
+
+  Expr VisitCallTIR(Call call) {
+    // call_tir needs to have at least two arguments
+    ICHECK_GE(call->args.size(), 2);
+    Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
+    ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
+    Optional<Array<runtime::NDArray>> arr_args =
+        MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
+    ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";
+    Optional<runtime::ShapeTuple> shape = MatchConstShape(call->sinfo_args[0]);
+    bool output_not_tuple = call->sinfo_args.size() == 1;
+    // Pattern 0: call constant function, const argument with const shape.
+    if (func && arr_args && shape && output_not_tuple) {
+      DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type());
+      // value_or will return value if it is not null, otherwise return or
+      return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype)
+          .value_or(call);
+    }
+    // TODO(hongyi): support const-fold tuple outputs
+    return std::move(call);
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call) final {
+    // post-order mutation
+    Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+
+    if (call->op.same_as(call_tir_op)) {
+      return VisitCallTIR(post_call);
+    }
+    return std::move(post_call);
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+    // `as` check checks if opt is not null and is instance of constant
+    if (opt.as<relax::ConstantNode>()) {
+      return opt.value();
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const VarNode* op) final {
+    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+    // `as` check checks if opt is not null and is instance of constant
+    if (opt.as<relax::ConstantNode>()) {
+      return opt.value();
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  // the context module to lookup functions
+  IRModule ctx_module_;
+  // cache for function build, via structural equality
+  std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>, StructuralHash, StructuralEqual>
+      func_build_cache_;
+};
+
+namespace transform {
+
+Pass FoldConstant() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        ConstantFolder folder(m);
+        return Downcast<Function>(folder(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py
new file mode 100644
index 0000000000..b96fb89e6c
--- /dev/null
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+use_np_array = tvm.testing.parameter(False, True)
+
+
+def test_bind_params(use_np_array):
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            A = T.match_buffer(x, (16, 16))
+            B = T.match_buffer(y, (16, 16))
+            C = T.match_buffer(z, (16, 16))
+            for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+                with T.block("matmul"):
+                    vi = T.axis.S(16, i0 * 4 + i1)
+                    vj = T.axis.S(16, j)
+                    vk = T.axis.R(16, k0 * 4 + k1)
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def main(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32"))
+            return gv0
+
+    x_np = np.random.rand(16, 16).astype(np.float32)
+    w_np = np.random.rand(16, 16).astype(np.float32)
+    x_tvm = tvm.nd.array(x_np)
+    w_tvm = tvm.nd.array(w_np)
+    params_dict = {"w": w_np if use_np_array else w_tvm}
+    mod = relax.transform.BindParams("main", params_dict)(InputModule)
+    assert len(mod["main"].params) == 1
+
+    target = tvm.target.Target("llvm")
+    ex_after = relax.vm.build(mod, target)
+    vm_after = relax.VirtualMachine(ex_after, tvm.cpu())
+    res_after = vm_after["main"](x_tvm)
+
+    ex_before = relax.vm.build(InputModule, target)
+    vm_before = relax.VirtualMachine(ex_before, tvm.cpu())
+    res_before = vm_before["main"](x_tvm, w_tvm)
+
+    tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy())
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
new file mode 100644
index 0000000000..32ee3e7000
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -0,0 +1,280 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+import numpy as np
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def gen_mod(mod, name, binding):
+    """Select relax function with name, rename to main and and bind constant.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The input module
+
+    name: str
+        The name of relax function to preserve and rename to main
+
+    binding: Dict[str, array]
+        The const parameter bindings
+    """
+    funcs = {}
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+
+    for k, v in mod.functions.items():
+        if isinstance(v, tvm.relax.Function):
+            if k.name_hint == name:
+                # rename to main
+                gv = tvm.ir.GlobalVar("main")
+                funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr(
+                    "global_symbol", "main"
+                )
+        else:
+            funcs[k] = v
+    mod = tvm.IRModule(funcs)
+    return relax.transform.BindParams("main", binding)(mod)
+
+
+def test_one_fold_addone():
+    # put the "before" and "expected" (after) functions in a single module
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((16, 16), "float32")):
+            lv0 = c1
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_one_fold_transpose():
+    # put the "before" and "expected" (after) functions in a single module
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None:
+            for i, j in T.grid(3, 2):
+                with T.block("transpose"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vj, vi]
+
+        @R.function
+        def before(c0: R.Tensor((2, 3), "float32")):
+            lv0 = relax.call_tir(func, (c0,), R.Tensor((3, 2), dtype="float32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((3, 2), "float32")):
+            lv0 = c1
+            return c1
+
+    c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
+    c1_np = c0_np.T
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_two_hop_addone():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> None:
+            for i, j in T.grid(2, 2):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @R.function
+        def before(c0: R.Tensor((2, 2), "float32")):
+            lv0 = relax.call_tir(addone, (c0,), R.Tensor((2, 2), dtype="float32"))
+            lv1 = relax.call_tir(addone, (lv0,), R.Tensor((2, 2), dtype="float32"))
+            return lv1
+
+        @R.function
+        def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), 
"float32")):
+            lv0 = c1
+            lv1 = c2
+            return c2
+
+    c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
+    c1_np = c0_np + 1
+    c2_np = c1_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_dataflow_fold():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("identity"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj]
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            with R.dataflow():
+                gv0 = relax.call_tir(identity, (c0,), R.Tensor((16, 16), dtype="float32"))
+                R.output(gv0)
+            return gv0
+
+        @R.function
+        def expected(c1: R.Tensor((16, 16), "float32")):
+            with R.dataflow():
+                gv0 = c1
+                R.output(gv0)
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_mixed_case():
+    @tvm.script.ir_module
+    class Module:
+        # TIR function can handle different cases.
+        @T.prim_func
+        def addone(a: T.handle, b: T.handle) -> None:
+            n = T.var("int32")
+            m = T.var("int32")
+            A = T.match_buffer(a, (n, m))
+            B = T.match_buffer(b, (n, m))
+            for i, j in T.grid(n, m):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @T.prim_func
+        def sub(
+            A: T.Buffer[(16, 16), "float32"],
+            B: T.Buffer[(16, 16), "float32"],
+            C: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("sub"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = A[vi, vj] - B[vi, vj]
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)):
+            n, m = T.var("int64"), T.var("int64")
+            x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
+            # this line cannot be folded because n is unknown
+            lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32"))
+            # this line can be folded
+            lv1 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32"))
+            # this line can be folded because all inputs are const
+            lv2 = relax.call_tir(sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
+            # this line can not be folded because x's shape is unknown
+            lv3 = relax.call_tir(sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
+            return lv3
+
+        @R.function
+        def expected(
+            c0: R.Tensor((16, 16), "float32"),
+            c1: R.Tensor((16, 16), "float32"),
+            c2: R.Tensor((16, 16), "float32"),
+            x: R.Tensor("float32", ndim=2),
+        ) -> R.Tensor:
+            n, m = T.var("int64"), T.var("int64")
+            x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
+            # this line cannot be folded because n is unknown
+            lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32"))
+            # this line can be folded
+            lv1 = c1
+            # this line can be folded because all inputs are const
+            lv2 = c2
+            # this line can not be folded because x's shape is unknown
+            lv3 = relax.call_tir(sub, (c2, x), R.Tensor((16, 16), dtype="float32"))
+            return lv3
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np + 1
+    c2_np = c0_np - c1_np
+
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2": c2_np})
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_int32_fold():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.int32(1)
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "int32")):
+            lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((16, 16), "int32")):
+            lv0 = c1
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

