This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2e2126f9e3 [Unity] Implement relax.Function.bind_symbolic_vars (#15509)
2e2126f9e3 is described below
commit 2e2126f9e342b1996f7b80ce0dd9e11095ef481c
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 16 09:10:23 2023 -0500
[Unity] Implement relax.Function.bind_symbolic_vars (#15509)
* [Unity] Implement relax.Function.bind_symbolic_vars
If a function has dynamic shape parameters, it can be useful to
replace them with static parameters (e.g. when producing several
models within the same family). This commit introduces a utility
function `relax.Function.bind_symbolic_vars`, which allows symbolic
variables to be replaced with static values.
This is related to the parameter binding done in
`relax.transform.BindParams`, but does not require the bound parameter
to be a fully static data array.
* Updating ExprBinder to use tir::Substitute
Previously, `ExprBinder` only checked whether a `PrimExpr` was a
symbolic variable to be replaced, but did not handle cases where a
`PrimExpr` contained a symbolic variable to be replaced. As a result,
when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be
updated to `[16,2*N]` instead of `[16,32]`. This commit updates
`ExprBinder` to use `tir::Substitute` to ensure all occurrences of the
symbolic variable are replaced.
* Special case for updating symbolic vars in strided_slice attrs
* Added IRModule pass to bind symbolic vars
* Update unit test to include pytest
* Co-authored-by: Sunghyun Park <[email protected]>
* Correct match mode in kProvideDefinition context
* Clean up implementation with VisitMode as a bitflag
---
include/tvm/relax/transform.h | 18 ++
python/tvm/relax/expr.py | 32 ++-
python/tvm/relax/transform/transform.py | 29 +++
src/relax/analysis/struct_info_analysis.cc | 63 +++--
src/relax/transform/bind_symbolic_vars.cc | 177 ++++++++++++++
src/relax/utils.cc | 58 ++++-
tests/python/relax/test_bind_symbolic_vars.py | 205 ++++++++++++++++
.../relax/test_transform_bind_symbolic_vars.py | 270 +++++++++++++++++++++
8 files changed, 827 insertions(+), 25 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 8d01262aab..05b26f0242 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -182,6 +182,24 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only =
false);
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray>
params);
+/*!
+ * \brief Bind symbolic vars to the given values (typically constant
+ * shape values).
+ *
+ * \param binding_map The dictionary of symbolic variables and their
+ * replacement values. Values may be any PrimExpr, not only
+ * constants. Dictionary keys may be either a
+ * `tir.Var` or a string name of the variable. If the variables
+ * are referred to by name, the name must uniquely identify a
+ * symbolic variable in each function where it is used. Each key
+ * must correspond to a symbolic variable of some updated function,
+ * otherwise the pass raises an error.
+ *
+ * \param func_name The name of the function in which to bind shape
+ * values. If NullOpt, all functions in the module will be
+ * updated.
+ *
+ * \return The Pass. The rewritten functions must not contain any
+ * undefined symbolic variables after the replacement.
+ */
+TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
+ Optional<String> func_name = NullOpt);
+
/*!
* \brief Fold constant expressions.
*
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index fb8ccf98d3..49b91ffb3d 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -19,7 +19,7 @@
"""The expression nodes of Relax."""
import typing
from numbers import Number
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Mapping
import numpy as _np # type: ignore
@@ -627,6 +627,36 @@ class Function(BaseFunc, Scriptable):
"""
return Call(self, args, None, None)
+ def bind_symbolic_vars(
+ self, binding_map: Mapping[Union[str, tvm.tir.Var], PrimExpr]
+ ) -> "Function":
+ """Return a new function with its symbolic variables replaced
+
+ Parameters
+ ----------
+ binding_map: Mapping[Union[str, tvm.tir.Var], PrimExpr]
+
+ The mapping of values to be replaced. Keys may be either
+ a `tir.Var` or a string name of the variable. If the
+ variables are referred to by name, the name must uniquely
+ identify a symbolic variable in the function. Python
+ integer values are converted to int64 constants.
+
+ Returns
+ -------
+ func: Function
+
+ The updated function. The original function is unchanged.
+
+ Raises
+ ------
+ tvm.TVMError
+
+ If a key does not refer to a symbolic variable of the
+ function, a name is ambiguous, a variable is bound more than
+ once, or the result would contain undefined symbolic
+ variables.
+ """
+
+ # Relax uses int64 for symbolic variables, but the FFI
+ # converts python integers into int32.
+ binding_map = {
+ key: tvm.tir.const(value, "int64") if isinstance(value, int) else value
+ for key, value in binding_map.items()
+ }
+
+ return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type: ignore
+
@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index f512e42bf6..438a6d1213 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -416,6 +416,35 @@ def BindParams(
return _ffi_api.BindParams(func_name, tvm_params) # type: ignore
+def BindSymbolicVars(
+ binding_map: Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr],
+ func_name: Optional[str] = None,
+) -> tvm.ir.transform.Pass:
+ """Bind symbolic variables of the module's Relax functions to values.
+
+ Parameters
+ ----------
+ binding_map : Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr]
+
+ The map from symbolic variable to its replacement. Keys may
+ be either a `tir.Var` or the string name of the variable; a
+ name must uniquely identify a symbolic variable in each
+ function where it is used. Values may be python integers
+ (converted to int64 constants) or `tir.PrimExpr`.
+
+ func_name: Optional[str]
+
+ The function name to be bound. If None (default), all
+ functions within the module will be updated.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+
+ The pass performing the binding.
+ """
+ # Relax uses int64 for symbolic variables, but the FFI
+ # converts python integers into int32.
+ binding_map = {
+ key: tvm.tir.const(value, "int64") if isinstance(value, int) else value
+ for key, value in binding_map.items()
+ }
+ return _ffi_api.BindSymbolicVars(binding_map, func_name) # type: ignore
+
+
def RunCodegen(
target_options: Optional[dict] = None,
entry_functions: Optional[List[str]] = None,
diff --git a/src/relax/analysis/struct_info_analysis.cc
b/src/relax/analysis/struct_info_analysis.cc
index 22c2e9bbd4..9fae776279 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -978,17 +978,34 @@ class SymbolicVarCollector : public relax::ExprVisitor,
using tir::ExprVisitor::VisitExpr;
using tir::ExprVisitor::VisitExpr_;
- // Possible mode of visitor
- enum class VisitMode {
- /*! \brief Check all vars are well-defined. */
- kDefault,
- /*! \brief Match define the vars on first occurrence. */
- kMatchVarDef,
+ // Possible mode of visitor, used as bit-flags
+ enum VisitMode {
+ /*! \brief Do nothing on encountering a symbolic variable */
+ kNone = 0,
+
+ /*! \brief Provide a variable definition on first occurrence.
+ *
+ * If a symbolic variable occurs at a site where a definition can
+ * be provided, mark the variable as having a definition.
+ */
+ kProvideDefinition = 1,
+
+ /*! \brief Require a variable definition on occurrence.
+ *
+ * If a symbolic variable occurs, and has not previously been
+ * defined, mark the variable as being free/undefined.
+ */
+ kRequireDefinition = 2,
};
void VisitExpr_(const FunctionNode* op) final {
- WithMode(VisitMode::kMatchVarDef, [&]() {
- ICHECK(mode_ == VisitMode::kMatchVarDef);
+ WithMode(VisitMode::kProvideDefinition, [&]() {
+ for (Var param : op->params) {
+ relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+ }
+ });
+
+ WithMode(VisitMode::kRequireDefinition, [&]() {
for (Var param : op->params) {
relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
}
@@ -998,7 +1015,8 @@ class SymbolicVarCollector : public relax::ExprVisitor,
}
void VisitBinding_(const MatchCastNode* binding) final {
- WithMode(VisitMode::kMatchVarDef, [&]() {
this->VisitStructInfo(binding->struct_info); });
+ WithMode(VisitMode(VisitMode::kProvideDefinition |
VisitMode::kRequireDefinition),
+ [&]() { this->VisitStructInfo(binding->struct_info); });
relax::ExprVisitor::VisitBinding_(binding);
}
@@ -1009,8 +1027,17 @@ class SymbolicVarCollector : public relax::ExprVisitor,
void VisitStructInfo_(const FuncStructInfoNode* op) final {
if (op->params.defined()) {
- WithMode(VisitMode::kMatchVarDef, [&]() {
- ICHECK(mode_ == VisitMode::kMatchVarDef);
+ // Visit the parameters once to collect bindings, and another
+ // time to collect usages. Otherwise, a symbolic variable
+ // defined by a later parameter may be treated as undefined when
+ // used by an earlier parameter.
+ WithMode(VisitMode::kProvideDefinition, [&]() {
+ for (StructInfo param : op->params.value()) {
+ this->VisitStructInfo(param);
+ }
+ });
+
+ WithMode(VisitMode::kRequireDefinition, [&]() {
for (StructInfo param : op->params.value()) {
this->VisitStructInfo(param);
}
@@ -1029,14 +1056,14 @@ class SymbolicVarCollector : public relax::ExprVisitor,
}
void VisitStructInfoExprField(const PrimExpr& expr) final {
- if (mode_ == VisitMode::kMatchVarDef && expr->IsInstance<tir::VarNode>()) {
- // populate symbolic var in first occurrence
- const auto& var = Downcast<tir::Var>(expr);
- if (defined_symbolic_var_.count(var) == 0) {
- defined_symbolic_var_.insert(var);
+ if (mode_ & VisitMode::kProvideDefinition) {
+ if (auto var = expr.as<tir::Var>()) {
+ defined_symbolic_var_.insert(var.value());
}
}
- tir::ExprVisitor::VisitExpr(expr);
+ if (mode_ & VisitMode::kRequireDefinition) {
+ tir::ExprVisitor::VisitExpr(expr);
+ }
}
void VisitExpr_(const tir::VarNode* op) final {
@@ -1056,7 +1083,7 @@ class SymbolicVarCollector : public relax::ExprVisitor,
}
/*! \brief The current visit mode. */
- VisitMode mode_ = VisitMode::kDefault;
+ VisitMode mode_ = VisitMode::kRequireDefinition;
/*! \brief The set of defined symbolic vars. */
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
defined_symbolic_var_;
/*! \brief The set of free/undefined symbolic vars. */
diff --git a/src/relax/transform/bind_symbolic_vars.cc
b/src/relax/transform/bind_symbolic_vars.cc
new file mode 100644
index 0000000000..2df9ed1f01
--- /dev/null
+++ b/src/relax/transform/bind_symbolic_vars.cc
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+Function FunctionBindSymbolicVars(Function func, Map<ObjectRef, PrimExpr>
obj_remap) {
+ // Early bail-out if no updates need to be made.
+ if (obj_remap.empty()) {
+ return func;
+ }
+
+ Array<tir::Var> old_symbolic_vars = DefinedSymbolicVars(func);
+
+ // Map from string to the variable(s) with that name.
+ std::unordered_map<std::string, Array<tir::Var>> string_lookup;
+ std::unordered_set<const tir::VarNode*> symbolic_var_set;
+ for (const auto& var : old_symbolic_vars) {
+ string_lookup[var->name_hint].push_back(var);
+ symbolic_var_set.insert(var.get());
+ }
+
+ // Replacement map to be used when rewriting the function.
+ Map<tir::Var, PrimExpr> var_remap;
+ for (const auto& [key, replacement] : obj_remap) {
+ if (auto opt = key.as<String>()) {
+ String string_key = opt.value();
+ auto it = string_lookup.find(string_key);
+ CHECK(it != string_lookup.end())
+ << "Function does not use symbolic var with name \"" << string_key
<< "\". "
+ << "Function has symbolic variables " << old_symbolic_vars;
+
+ CHECK_EQ(it->second.size(), 1)
+ << "Function contains multiple symbolic variables with name \"" <<
string_key << "\". "
+ << "The TIR variables " << it->second << " are all named \"" <<
string_key << "\"";
+ auto var = it->second[0];
+
+ CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was
defined multiple times";
+ var_remap.Set(var, replacement);
+ } else if (auto opt = key.as<tir::Var>()) {
+ auto var = opt.value();
+
+ CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was
defined multiple times";
+ CHECK(symbolic_var_set.count(var.get()))
+ << "Function does not use variable " << var << " as a symbolic
variable. "
+ << "Function has symbolic variables " << old_symbolic_vars;
+ var_remap.Set(var, replacement);
+ } else {
+ LOG(FATAL) << "Expected symbolic variable to be a tir::Var or a string
name, "
+ << "but " << key << " was of type " << key->GetTypeKey();
+ }
+ }
+
+ auto new_func = Downcast<Function>(Bind(func, {}, var_remap));
+
+ auto free_symbolic_vars = FreeSymbolicVars(new_func);
+
+ CHECK(free_symbolic_vars.empty())
+ << "Resulting function should not have any undefined symbolic variables,
"
+ << "but TIR variables " << free_symbolic_vars << " were undefined.";
+
+ return new_func;
+}
+
+namespace {
+IRModule ModuleBindSymbolicVars(IRModule mod, Map<ObjectRef, PrimExpr>
binding_map) {
+ std::unordered_set<const Object*> used;
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt = base_func.as<Function>()) {
+ auto func = opt.value();
+
+ // Collect bindings that are used by this function.
+ auto func_binding_map = [&]() -> Map<ObjectRef, PrimExpr> {
+ std::unordered_set<std::string> var_names;
+ std::unordered_set<const tir::VarNode*> vars;
+ for (const auto& var : DefinedSymbolicVars(func)) {
+ var_names.insert(var->name_hint);
+ vars.insert(var.get());
+ }
+
+ Map<ObjectRef, PrimExpr> out;
+ for (const auto& [key, replacement] : binding_map) {
+ bool used_by_function = false;
+ if (auto opt = key.as<String>()) {
+ used_by_function = var_names.count(opt.value());
+ } else if (auto ptr = key.as<tir::VarNode>()) {
+ used_by_function = vars.count(ptr);
+ } else {
+ LOG(FATAL) << "Expected symbolic variable to be a tir::Var "
+ << "or a string name, but " << key << " was of type "
<< key->GetTypeKey();
+ }
+ if (used_by_function) {
+ used.insert(key.get());
+ out.Set(key, replacement);
+ }
+ }
+ return out;
+ }();
+ func = FunctionBindSymbolicVars(func, func_binding_map);
+
+ if (!func.same_as(base_func)) {
+ updates->Add(gvar, func);
+ }
+ }
+ }
+
+ Array<ObjectRef> unused;
+ for (const auto& [key, replacement] : binding_map) {
+ if (!used.count(key.get())) {
+ unused.push_back(key);
+ }
+ }
+ CHECK_EQ(unused.size(), 0) << "Binding map contains keys " << unused
+ << ", which did not correspond to any symbolic
variables "
+ << "in the module.";
+
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
+ }
+ return mod;
+}
+} // namespace
+
+TVM_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars);
+
+namespace transform {
+
+Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String>
func_name) {
+ auto pass_func = [=](IRModule mod, PassContext context) -> IRModule {
+ if (func_name) {
+ auto gvar = mod->GetGlobalVar(func_name.value());
+ auto func = Downcast<Function>(mod->Lookup(gvar));
+ auto new_func = FunctionBindSymbolicVars(func, binding_map);
+ if (!func.same_as(new_func)) {
+ mod.CopyOnWrite()->Update(gvar, new_func);
+ }
+ } else {
+ mod = ModuleBindSymbolicVars(mod, binding_map);
+ }
+ return mod;
+ };
+
+ return tvm::transform::CreateModulePass(pass_func, 1,
"relax.BindSymbolicVars", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index b0816b0eda..ccb72805e3 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -20,7 +20,9 @@
#include "transform/utils.h"
#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/index.h>
#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace relax {
@@ -33,6 +35,8 @@ class ExprBinder : public ExprMutator {
: args_map_(args_map), symbolic_var_map_(symbolic_var_map) {}
private:
+ using ExprMutator::VisitExpr_;
+
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<Var> params;
bool all_params_unchanged = true;
@@ -61,6 +65,49 @@ class ExprBinder : public ExprMutator {
}
}
+ Expr VisitExpr_(const CallNode* op) final {
+ auto call_node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+ // Special case for strided_slice
+ //
+ // The strided_slice operator currently stores the begins/ends in
+ // the CallNode::attrs. Because the CallNode::attrs is only
+ // intended to store static information, any PrimExpr members in
+ // the attributes are not visited by `ExprMutator::VisitPrimExpr`.
+ // Therefore, these must be explicitly visited.
+ //
+ // When the strided_slice operator is updated to store begins/ends
+ // as a tuple of `relax::PrimValue` in the arguments, this special
+ // case can be removed.
+ static auto strided_slice_op = Op::Get("relax.strided_slice");
+ if (call_node->op.same_as(strided_slice_op)) {
+ auto attrs = call_node->attrs.as<StridedSliceAttrs>();
+
+ auto visit_prim_expr = [this](const auto& expr) { return
VisitPrimExpr(expr); };
+
+ Array<PrimExpr> begin = attrs->begin.Map(visit_prim_expr);
+ Array<PrimExpr> end = attrs->end.Map(visit_prim_expr);
+ auto strides = attrs->strides;
+ if (strides.defined()) {
+ strides = strides.value().Map(visit_prim_expr);
+ }
+
+ bool all_same = begin.same_as(attrs->begin) && end.same_as(attrs->end) &&
+ (!strides.defined() || strides.same_as(attrs->strides));
+ if (!all_same) {
+ ObjectPtr<StridedSliceAttrs> new_attrs =
make_object<StridedSliceAttrs>();
+ new_attrs->axes = attrs->axes;
+ new_attrs->begin = std::move(begin);
+ new_attrs->end = std::move(end);
+ new_attrs->strides = std::move(strides);
+ new_attrs->assume_inbound = attrs->assume_inbound;
+ call_node.CopyOnWrite()->attrs = Attrs(new_attrs);
+ }
+ }
+
+ return std::move(call_node);
+ }
+
Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
auto it = args_map_.find(id);
@@ -72,13 +119,12 @@ class ExprBinder : public ExprMutator {
}
PrimExpr VisitPrimExpr(const PrimExpr& expr) final {
- if (const tir::VarNode* var = expr.as<tir::VarNode>()) {
- auto it = symbolic_var_map_.find(GetRef<tir::Var>(var));
- if (it != symbolic_var_map_.end()) {
- return (*it).second;
- }
+ auto new_expr = tir::Substitute(expr, symbolic_var_map_);
+ if (!expr.same_as(new_expr)) {
+ arith::Analyzer analyzer;
+ new_expr = analyzer.Simplify(new_expr);
}
- return ExprMutator::VisitPrimExpr(expr);
+ return new_expr;
}
private:
diff --git a/tests/python/relax/test_bind_symbolic_vars.py
b/tests/python/relax/test_bind_symbolic_vars.py
new file mode 100644
index 0000000000..1dc1189a67
--- /dev/null
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.script import relax as R, tir as T
+
+replace_by_tir_var = tvm.testing.parameter(
+ by_dict={"replace-by-string": False, "replace-by-tir-var": True}
+)
+
+
+def test_bind_static_value(replace_by_tir_var):
+ """Symbolic vars may be replaced
+
+ The replaced variables may be given either as strings, or as TIR variables
+ """
+
+ @R.function(private=True)
+ def before(A: R.Tensor(("M", "K")), B: R.Tensor(("K", "N"))) ->
R.Tensor(("M", "N")):
+ return R.matmul(A, B)
+
+ @R.function(private=True)
+ def expected(A: R.Tensor((128, 64)), B: R.Tensor((64, 32))) ->
R.Tensor((128, 32)):
+ return R.matmul(A, B)
+
+ if replace_by_tir_var:
+ M, K = before.params[0].struct_info.shape
+ _, N = before.params[1].struct_info.shape
+ symbolic_var_map = {M: 128, K: 64, N: 32}
+ else:
+ symbolic_var_map = {"M": 128, "K": 64, "N": 32}
+
+ after = before.bind_symbolic_vars(symbolic_var_map)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_with_duplicate_var_names():
+ """Duplicate variable names may not be replaced by string
+
+ Two TIR variables may have the same name. If two symbolic
+ variables share the same name, the replacement map may not refer
+ to that variable by string.
+ """
+ N1 = tvm.tir.Var("N", "int64")
+ N2 = tvm.tir.Var("N", "int64")
+
+ @R.function(private=True)
+ def func(A: R.Tensor((N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor((N1,
N2)):
+ out: R.Tensor((N1, N2)) = R.matmul(A, B)
+ return out
+
+ with pytest.raises(tvm.TVMError):
+ func.bind_symbolic_vars({"N": 64})
+
+
+def test_string_var_when_other_var_has_duplicate_var_names():
+ """Like test_error_with_duplicate_var_names, but replacing a different
variable
+
+ If two TIR variables share the same name, the restriction against
+ replacing variables by name only applies to those duplicate names.
+ Other variables may still be replaced by name.
+ """
+ N1 = tvm.tir.Var("N", "int64")
+ N2 = tvm.tir.Var("N", "int64")
+ BatchSize = tvm.tir.Var("BatchSize", "int64")
+
+ @R.function(private=True)
+ def before(
+ A: R.Tensor((BatchSize, N1, N1)), B: R.Tensor((N1, N2))
+ ) -> R.Tensor((BatchSize, N1, N2)):
+ out: R.Tensor((BatchSize, N1, N2)) = R.matmul(A, B)
+ return out
+
+ @R.function(private=True)
+ def expected(A: R.Tensor((16, N1, N1)), B: R.Tensor((N1, N2))) ->
R.Tensor((16, N1, N2)):
+ out: R.Tensor((16, N1, N2)) = R.matmul(A, B)
+ return out
+
+ after = before.bind_symbolic_vars({"BatchSize": 16})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_with_nonexisting_var_name():
+ """A string name of a symbolic var must be used by the function"""
+
+ @R.function(private=True)
+ def func(A: R.Tensor(("M", "N"))):
+ return A
+
+ with pytest.raises(tvm.TVMError):
+ func.bind_symbolic_vars({"non_existing_symbolic_var": 64})
+
+
+def test_error_with_nonexisting_tir_var():
+ """A TIR symbolic var must be a symbolic var of the function"""
+
+ @R.function(private=True)
+ def func(A: R.Tensor(["M", "N"])):
+ return A
+
+ with pytest.raises(tvm.TVMError):
+ func.bind_symbolic_vars({tvm.tir.Var("M", "int64"): 64})
+
+
+def test_error_with_multiple_definitions():
+ """The string/TIR var syntaxes may not define the same variable"""
+
+ @R.function(private=True)
+ def func(A: R.Tensor(["M", "N"])):
+ return A
+
+ tir_var = func.params[0].struct_info.shape[0]
+ symbolic_var_map = {tir_var: 0, "M": 0}
+
+ with pytest.raises(tvm.TVMError):
+ func.bind_symbolic_vars(symbolic_var_map)
+
+
+def test_error_if_output_has_undefined():
+ """The replacements may not introduce undefined symbolic vars"""
+
+ @R.function(private=True)
+ def func(A: R.Tensor(["M", "N"])):
+ return A
+
+ outside_var = tvm.tir.Var("outside_var", "int64")
+
+ with pytest.raises(tvm.TVMError):
+ func.bind_symbolic_vars({"M": outside_var * 2})
+
+
+def test_replacements_may_produce_new_symbolic_vars():
+ """The output may introduce symbolic vars, but they must be bound"""
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M", "N"])):
+ return A
+
+ @R.function(private=True)
+ def expected(A: R.Tensor(["outside_var * 2", "outside_var"])):
+ return A
+
+ outside_var = tvm.tir.Var("outside_var", "int64")
+
+ after = before.bind_symbolic_vars({"M": outside_var * 2, "N": outside_var})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_symbolic_vars_in_shape():
+ """The bound variable should be replaced when appearing in struct info"""
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M", "N"])):
+ M = T.int64()
+ N = T.int64()
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M *
N]))
+ return B
+
+ @R.function(private=True)
+ def expected(A: R.Tensor(["M", 16])):
+ M = T.int64()
+ B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+ return B
+
+ after = before.bind_symbolic_vars({"N": 16})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_strided_slice():
+ """relax.op.strided_slice stores PrimExpr attributes"""
+
+ @R.function(private=True)
+ def before(A: R.Tensor(["M", "N"])):
+ N = T.int64()
+ B = R.strided_slice(A, [1], [0], [N // 4])
+ return B
+
+ @R.function(private=True)
+ def expected(A: R.Tensor(["M", 32])):
+ B = R.strided_slice(A, [1], [0], [8])
+ return B
+
+ after = before.bind_symbolic_vars({"N": 32})
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_bind_symbolic_vars.py
b/tests/python/relax/test_transform_bind_symbolic_vars.py
new file mode 100644
index 0000000000..687945a650
--- /dev/null
+++ b/tests/python/relax/test_transform_bind_symbolic_vars.py
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_bind_tensors():
+ """Symbolic variables may occur in Tensor shapes"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(("batch", "m"), dtype="float32"),
+ w0: R.Tensor(("m", "n"), dtype="float32"),
+ w1: R.Tensor(("k", 10), dtype="float32"),
+ ) -> R.Tensor(("batch", "k"), dtype="float32"):
+ batch = T.Var("batch", "int64")
+ n = T.Var("n", "int64")
+ k = T.Var("k", "int64")
+ with R.dataflow():
+ lv0 = R.call_dps_packed(
+ "test0", (x, w0), out_sinfo=R.Tensor((batch, n),
dtype="float32")
+ )
+ out = R.call_dps_packed(
+ "test1", (lv0, w1), out_sinfo=R.Tensor((batch, k),
dtype="float32")
+ )
+ R.output(out)
+ return out
+
+ symvar_map = {"batch": 1, "k": 3}
+ target_func_name = "main"
+ After = relax.transform.BindSymbolicVars(symvar_map,
target_func_name)(Before)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, "m"), dtype="float32"),
+ w0: R.Tensor(("m", "n"), dtype="float32"),
+ w1: R.Tensor((3, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3), dtype="float32"):
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.call_dps_packed(
+ "test0", (x, w0), out_sinfo=R.Tensor((1, n),
dtype="float32")
+ )
+ out = R.call_dps_packed(
+ "test1", (lv0, w1), out_sinfo=R.Tensor((1, 3),
dtype="float32")
+ )
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_shape():
+ """Symbolic variables may occur in ShapeExpr"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Shape(("batch", "m")),
+ w0: R.Shape(("m", "n")),
+ w1: R.Shape(("k", 10)),
+ ) -> R.Shape(("batch", "k")):
+ batch = T.Var("batch", "int64")
+ n = T.Var("n", "int64")
+ k = T.Var("k", "int64")
+ with R.dataflow():
+ lv0 = R.call_dps_packed("test0", (x, w0),
out_sinfo=R.Tensor((batch, n)))
+ out = R.call_dps_packed("test1", (lv0, w1),
out_sinfo=R.Tensor((batch, k)))
+ R.output(out)
+ return out
+
+ symvar_map = {"batch": 1, "k": 3}
+ target_func_name = "main"
+ After = relax.transform.BindSymbolicVars(symvar_map,
target_func_name)(Before)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Shape([1, "m"]), w0: R.Shape(["m", "n"]), w1: R.Shape([3, 10])
+ ) -> R.Shape([1, 3]):
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.call_dps_packed("test0", (x, w0),
out_sinfo=R.Tensor((1, n)))
+ out = R.call_dps_packed("test1", (lv0, w1),
out_sinfo=R.Tensor((1, 3)))
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_arith():
+ """Symbolic shapes may use TIR arithmetic expressions"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(("batch", "m-1"), dtype="float32"),
+ w0: R.Tensor(("m", "n"), dtype="float32"),
+ w1: R.Tensor(("k", 10), dtype="float32"),
+ ) -> R.Tensor(("batch", "k*m"), dtype="float32"):
+ batch = T.Var("batch", "int64")
+ m = T.Var("m", "int64")
+ n = T.Var("n", "int64")
+ k = T.Var("k", "int64")
+ with R.dataflow():
+ lv0 = R.call_dps_packed(
+ "test0",
+ (x, w0),
+ out_sinfo=R.Tensor((batch, m + n), dtype="float32"),
+ )
+ out = R.call_dps_packed(
+ "test1",
+ (lv0, w1),
+ out_sinfo=R.Tensor((batch, k + n), dtype="float32"),
+ )
+ R.output(out)
+ return out
+
+ symvar_map = {"batch": 1, "k": 2, "m": 3}
+ target_func_name = "main"
+ After = relax.transform.BindSymbolicVars(symvar_map,
target_func_name)(Before)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2), dtype="float32"),
+ w0: R.Tensor((3, "n"), dtype="float32"),
+ w1: R.Tensor((2, 10), dtype="float32"),
+ ) -> R.Tensor((1, 6), dtype="float32"):
+ n = T.int64()
+ with R.dataflow():
+ lv0 = R.call_dps_packed(
+ "test0", (x, w0), out_sinfo=R.Tensor((1, n + 3),
dtype="float32")
+ )
+ out = R.call_dps_packed(
+ "test1", (lv0, w1), out_sinfo=R.Tensor((1, n + 2),
dtype="float32")
+ )
+ R.output(out)
+ return out
+
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_multiple_variables_by_name():
+ """String names may be used to replace across multiple functions"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main_1(x: R.Tensor(("m", 16), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", 16), dtype="float32")):
+ return x
+
+ After = relax.transform.BindSymbolicVars({"n": 16})(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_single_variable_by_identity():
+ """TIR variables may be used to replace a specific var"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main_1(x: R.Tensor(("m", 16), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ main_1_n = Before["main_1"].params[0].struct_info.shape[1]
+ After = relax.transform.BindSymbolicVars({main_1_n: 16})(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_single_variable_by_function_name():
+ """Variable name and function name may be used to replace a specific var"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main_1(x: R.Tensor(("m", 16), dtype="float32")):
+ return x
+
+ @R.function
+ def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ After = relax.transform.BindSymbolicVars({"n": 16}, "main_1")(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_error_for_unused_replacement():
+ """Each replacement must be used"""
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), dtype="float32")):
+ return x
+
+ with pytest.raises(tvm.TVMError):
+ relax.transform.BindSymbolicVars({"non_existing_var_name": 16})(Before)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()