This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6a3fadc065 [Unity][Transform] Handle `call_tir_inplace` in `FuseTIR` and `FuseOps` (#16487)
6a3fadc065 is described below
commit 6a3fadc0654ecf9557ffe08d24677684c96e80b0
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 6 15:38:29 2024 -0500
[Unity][Transform] Handle `call_tir_inplace` in `FuseTIR` and `FuseOps` (#16487)
* WIP initial commit
* Handle in-place calls in FuseTIR
* Formatting
* Add test case for FuseOps
* Address review comments related to clarity
* Use a set to ensure in-place indices will be unique
* Add test case where PrimFunc is used both in-place and DPS
* Explicitly check for duplicate index
---
src/relax/transform/fuse_ops.cc | 10 +-
src/relax/transform/fuse_tir.cc | 158 ++++++++++---
tests/python/relax/test_transform_fuse_ops.py | 141 +++++++++++
tests/python/relax/test_transform_fuse_tir.py | 324 ++++++++++++++++++++++++++
4 files changed, 600 insertions(+), 33 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index b0eeba399e..32780f6dd2 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor {
ICHECK_NOTNULL(binding_var_node);
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
OpPatternKind pattern = OpPatternKind::kOpaque;
Array<Expr> args = call->args;
@@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor {
// - Otherwise, the pattern of the current binding variable node is set to
`kOpaque`, and we
// recurse into the call expression.
const auto* op = call->op.as<OpNode>();
- if (op == call_tir_op_.get()) {
+ if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) {
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
@@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator {
* function accordingly
* \param binding The binding to be appended
* \note Allowed bindings are:
- * - VarBinding with value being a call node calling `relax.call_tir`.
+ * - VarBinding with value being a call node calling `relax.call_tir` or
+ * `relax.call_tir_inplace`.
* - VarBinding with value being a tuple-get-item node.
* // TODO(tvm-team): handle match shape
*/
@@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator {
if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (const auto* call = var_binding->value.as<CallNode>()) {
- if (call->op == Op::Get("relax.call_tir")) {
+ if (call->op == Op::Get("relax.call_tir") ||
+ call->op == Op::Get("relax.call_tir_inplace")) {
// Update the name of the function.
name_hint_ = name_hint_ + "_" +
Downcast<GlobalVar>(call->args[0])->name_hint;
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 1c25229d88..4ad291e91c 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
@@ -367,9 +368,10 @@ class FusedTIRConstructor : public ExprVisitor {
* \brief Construct a fused TIR PrimFunc from a relax sub-function
* \param mod The IRModule
* \param gv The global var of relax subfunction to be fused into one
PrimFunc
- * \return The fused TIR PrimFunc
+ * \return The fused TIR PrimFunc and the in-place indices (non-empty for an
in-place call)
*/
- static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+ static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR(const IRModule&
mod,
+ const GlobalVar&
gv) {
FusedTIRConstructor visitor(mod, gv->name_hint);
BaseFunc f = mod->Lookup(gv);
CHECK(f->IsInstance<relax::FunctionNode>())
@@ -377,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor {
CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
<< "Expected a function with attr `kPrimitive`";
visitor(Downcast<relax::Function>(f));
- return visitor.fused_tir_;
+ Array<Integer> inplace_indices;
+ for (size_t idx : visitor.inplace_indices_) {
+ inplace_indices.push_back(Integer(idx));
+ }
+ return {visitor.fused_tir_, inplace_indices};
}
private:
@@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor {
auto it = func_info_.expr2buffers.find(body);
ICHECK(it != func_info_.expr2buffers.end())
<< "Fail to detect output buffers for function body";
+
const Array<tir::Buffer>& buffers = (*it).second;
+
+ // map of input buffers to indices (helpful for detecting in-place inputs)
+ std::unordered_map<tir::Buffer, size_t, ObjectPtrHash, ObjectPtrEqual>
buffer_to_idx;
+ std::unordered_map<tir::Var, size_t, ObjectPtrHash, ObjectPtrEqual>
input_to_idx;
+ for (size_t i = 0; i < func_info_.params.size(); i++) {
+ input_to_idx[func_info_.params[i]] = i;
+ }
+ for (auto [var, buffer] : func_info_.buffer_map) {
+ if (auto it = input_to_idx.find(var); it != input_to_idx.end()) {
+ buffer_to_idx[buffer] = (*it).second;
+ }
+ }
+
+ // numbered separately because the number of output *vars* might differ
from the
+ // number of outputs if there are in-place inputs
+ int out_idx = 0;
for (size_t i = 0; i < buffers.size(); ++i) {
- tir::Var param = tir::Var("p_output" + std::to_string(i),
PrimType(DataType::Handle()));
+ // Do not add output vars for in-place inputs
+ // (i.e., already listed in the buffer map. This would result
+ // in duplicates in the buffer map otherwise)
+ if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end())
{
+ auto idx = (*it).second;
+ CHECK(!inplace_indices_.count(idx))
+ << "In-place index " << idx << " used twice! An argument must be
aliased.";
+ inplace_indices_.insert(idx);
+ continue;
+ }
+
+ tir::Var param = tir::Var("p_output" + std::to_string(out_idx),
PrimType(DataType::Handle()));
+ out_idx++;
func_info_.buffer_map.Set(param, buffers[i]);
func_info_.params.push_back(param);
func_info_.output_buffers.insert(buffers[i].get());
@@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor {
void VisitExpr_(const CallNode* call) final {
ExprVisitor::VisitExpr_(call);
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
- ICHECK(call->op == call_tir_op_)
- << "Only call_tir is supported in primitive function, but got: " <<
GetRef<Expr>(call);
+ static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+ ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+ << "Only call_tir and call_tir_inplace are supported in primitive
function, but got: "
+ << GetRef<Expr>(call);
// Step 1. Get Global var and PrimFunc
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
@@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor {
MapInputBuffer(prim_func, call->args[1]);
const Array<Array<PrimExpr>>& output_buffer_shapes =
GetCallTIROutputShapes(call);
- AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func,
output_buffer_shapes);
+ AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes);
// Step 6. Update tir_vars
if (call->args.size() > 2) {
@@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor {
*/
static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
- ICHECK(call->op.same_as(call_tir_op_));
+ static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+ ICHECK(call->op.same_as(call_tir_op_) ||
call->op.same_as(call_tir_inplace_op_));
ICHECK_EQ(call->sinfo_args.size(), 1);
auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
@@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor {
}
}
}
- // Make sure every buffers are mapped.
+ // Make sure every buffer is mapped.
ICHECK_EQ(buffer_idx, buffers.size());
}
@@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor {
MapArgsToBuffer(arg_list, buffer_list);
}
- static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
size_t output_size) {
+ static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+ }
+
+ static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
+ const Array<Integer>&
output_indices) {
size_t n = func->params.size();
int symbolic_var_index = -1;
+ size_t output_size = output_indices.size();
ICHECK_GE(n, output_size);
- for (size_t i = 0; i < n; ++i) {
- const tir::Var& param = func->params[i];
+
+ Array<tir::Var> ret;
+ for (auto idx : output_indices) {
+ int i = idx.IntValue();
+ const tir::Var& param = func->params[static_cast<size_t>(i)];
if (param->dtype.is_int() || param->dtype.is_uint()) {
if (symbolic_var_index == -1) symbolic_var_index = i;
} else if (param->dtype.is_handle()) {
CHECK(symbolic_var_index == -1) << "The scalar input should be at the
ending of the "
"parameter list.";
+ ret.push_back(param);
} else {
LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle
or scalar, but got: "
<< param->dtype;
}
}
+
size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index;
ICHECK_GE(end_index, output_size);
- size_t begin_index = end_index - output_size;
- Array<tir::Var> output_params{func->params.begin() + begin_index,
- func->params.begin() + end_index};
- return output_params;
+ return ret;
}
/*!
@@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor {
* \param func The old TIR PrimFunc
* \param output_shapes The shape of output params.
*/
- void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
+ void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc&
func,
const Array<Array<PrimExpr>>& output_shapes)
{
+ bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace"));
+
size_t n = func->params.size();
+ int num_inputs = Downcast<Tuple>(call->args[1])->fields.size();
size_t output_size = output_shapes.size();
ICHECK_GE(n, output_size);
- // Allocate intermediate buffer
- Array<tir::Buffer> alloc_buffers;
- Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_size);
+ Array<tir::Buffer> output_buffers;
+ Array<Integer> output_idxs;
+ if (is_inplace) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs));
+ } else {
+ for (size_t i = 0; i < output_size; i++) {
+ output_idxs.push_back(num_inputs + i);
+ }
+ }
+
+ Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_idxs);
+ auto input_buffers = func_info_.expr2buffers.Get(call->args[1]);
for (size_t i = 0; i < output_size; ++i) {
const tir::Var& param = output_params[i];
const tir::Buffer& buffer = func->buffer_map.at(param);
+ // if this is an inplace output, do not do an intermediate allocation
+ if (output_idxs[i].IntValue() < num_inputs) {
+ CHECK(input_buffers.defined()) << "Inplace functions must have some
defined input";
+
output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]);
+ continue;
+ }
+
auto unify_name_hints = [this, &buffer]() {
String base_name = buffer->name;
String unique_name = base_name + "_intermediate";
@@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor {
n->name = unify_name_hints();
tir::Buffer new_buffer(n);
func_info_.alloc_buffers.push_back(new_buffer);
- alloc_buffers.push_back(new_buffer);
+ output_buffers.push_back(new_buffer);
// Match the shape of the output buffer with the shape
func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
func_info_.buffer_subst_map.Set(buffer, new_buffer);
}
// Update expr2buffers
- func_info_.expr2buffers.Set(expr, alloc_buffers);
+ func_info_.expr2buffers.Set(GetRef<Expr>(call), output_buffers);
}
/*!
@@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor {
FuseFuncInfo func_info_;
/*! \brief The tir function after fusion*/
tir::PrimFunc fused_tir_;
+ /*! \brief Indices of inputs that are used for in-place computation */
+ std::unordered_set<size_t> inplace_indices_;
};
std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const
Var& tuple_var) {
@@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator {
for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() &&
func->HasNonzeroAttr(attr::kPrimitive)) {
- tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
- mutator.fused_tir_funcs_.Set(gv, fused_tir);
+ const auto& [prim_func, indices] =
FusedTIRConstructor::GetFusedTIR(mod, gv);
+ mutator.fused_tir_funcs_.Set(gv, prim_func);
+ if (!indices.empty()) {
+ mutator.inplace_indices_.Set(gv, indices);
+ }
}
}
@@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator {
Expr VisitExpr_(const CallNode* op) final {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
Call call =
Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));
@@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known
value.";
PrimExpr expr = prim_value->value.value();
- CHECK(expr->IsInstance<tir::VarNode>())
- << "FuseTIR currently requires all R.Prim arguments to provide
a single tir::Var.";
+ CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently
requires all R.Prim "
+ "arguments to provide a
single tir::Var.";
tir_vars.push_back(expr);
} else {
arg_list.push_back(arg);
}
}
- // Step b. Create call_tir
+ // Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
- return Call(call_tir_op_, call_args, call->attrs,
{GetStructInfo(call)});
+ Op call_op = call_tir_op_;
+ Attrs call_attrs = call->attrs;
+ if (auto it = inplace_indices_.find(old_gv); it !=
inplace_indices_.end()) {
+ call_op = call_tir_inplace_op_;
+ auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
+ inplace_attrs->inplace_indices = (*it).second;
+ call_attrs = Attrs(inplace_attrs);
+ }
+ return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
} else {
// Case 1.2. The callee function is not primitive, nothing to do.
return call;
}
- } else if (call->op == call_tir_op_) {
- // Case 2. It is a call_tir, re-emit the PrimFunc.
+ } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
+ // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func =
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
@@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator {
const IRModule& mod_;
/*! \brief The map from global var of primitive relax function to generated
prim func. */
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
+ /*! \brief The map from global var of primitive relax function to in-place
indices
+ * (if there are any). */
+ Map<GlobalVar, Array<Integer>> inplace_indices_;
};
IRModule FuseTIR(IRModule mod) {
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index 1a4a630e3e..3cd608d8ee 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1501,5 +1501,146 @@ def test_partially_used_tuple_param():
_check(Module, Expected)
+def test_call_tir_inplace():
+ @I.ir_module
+ class Module:
+ @T.prim_func(private=True)
+ def add(
+ A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ B: T.Buffer((), "float32"),
+ Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[()])
+ T.writes(Out[v_ax0, v_ax1])
+ Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+ @T.prim_func(private=True)
+ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(A[v_i0, v_i1])
+ T.writes(A[v_i0, v_i1])
+ A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+ @T.prim_func(private=True)
+ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1])
+ T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.add,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv1 = R.call_tir_inplace(
+ cls.exp_inplace,
+ (lv,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ gv = R.call_tir_inplace(
+ cls.squeeze_inplace,
+ (lv1,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def add(
+ A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ B: T.Buffer((), "float32"),
+ Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[()])
+ T.writes(Out[v_ax0, v_ax1])
+ Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+ @T.prim_func(private=True)
+ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(A[v_i0, v_i1])
+ T.writes(A[v_i0, v_i1])
+ A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+ @T.prim_func(private=True)
+ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1])
+ T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+ @R.function(private=True)
+ def fused_add_exp_inplace_squeeze_inplace(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Expected
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.add,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv1 = R.call_tir_inplace(
+ cls.exp_inplace,
+ (lv,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ gv = R.call_tir_inplace(
+ cls.squeeze_inplace,
+ (lv1,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv1: R.Tensor(
+ (10, 20), dtype="float32"
+ ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0)
+ R.output(gv1)
+ return gv1
+
+ _check(Module, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index 143670c701..c0a6f4448b 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1930,5 +1930,329 @@ def test_gather():
_check(Before, After)
+def test_inplace_simple():
+ @I.ir_module
+ class Module:
+ I.module_attrs({"foo": "bar"})
+
+ @T.prim_func(private=True)
+ def add_inplace(
+ A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B:
T.Buffer((), "float32")
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ # T.reads(A[v_ax0, v_ax1], B[()])
+ # T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+ @T.prim_func(private=True)
+ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ # T.reads(A[v_i0, v_i1])
+ # T.writes(A[v_i0, v_i1])
+ A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+ @T.prim_func(private=True)
+ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ # T.reads(A[v_ax0, v_ax1])
+ # T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+ @R.function(private=True)
+ def fused_add_exp_squeeze(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Module
+ with R.dataflow():
+ # This overwrites x and is actually evil because the function
is marked as pure
+ # but we are doing it just to test the pass. The automatic
DataflowUseInplaceCalls
+ # transformation will not produce code like this, but it may
make sense to do it
+ # if ownership of x is fully and truly transferred.
+ # Users should apply with caution!
+ lv = R.call_tir_inplace(
+ cls.add_inplace,
+ (x, p0),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv1 = R.call_tir_inplace(
+ cls.exp_inplace,
+ (lv,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ gv = R.call_tir_inplace(
+ cls.squeeze_inplace,
+ (lv1,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") =
cls.fused_add_exp_squeeze(x, p0)
+ R.output(gv1)
+ return gv1
+
+ @I.ir_module
+ class Expected:
+ I.module_attrs({"foo": "bar"})
+
+ @T.prim_func(private=True)
+ def fused_add_exp_squeeze(
+ x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0:
T.Buffer((), "float32")
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ x[v_i0, v_i1] = T.exp(x[v_i0, v_i1])
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ x[v_ax0, v_ax1] = x[v_ax0, v_ax1]
+
+ # note that this will clobber x! Use with caution
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace(
+ cls.fused_add_exp_squeeze,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ inplace_indices=[0],
+ )
+ R.output(gv1)
+ return gv1
+
+ _check(Module, Expected)
+
+
+def test_fuse_inplace_and_non_inplace():
+ @I.ir_module
+ class Module:
+ I.module_attrs({"foo": "bar"})
+
+ @T.prim_func(private=True)
+ def add(
+ A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ B: T.Buffer((), "float32"),
+ Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+ @T.prim_func(private=True)
+ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+ @T.prim_func(private=True)
+ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)),
"float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+ @R.function(private=True)
+ def fused_add_exp_squeeze(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.add,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv1 = R.call_tir_inplace(
+ cls.exp_inplace,
+ (lv,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ gv = R.call_tir_inplace(
+ cls.squeeze_inplace,
+ (lv1,),
+ inplace_indices=[0],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") =
cls.fused_add_exp_squeeze(x, p0)
+ R.output(gv1)
+ return gv1
+
+ @I.ir_module
+ class Expected:
+ I.module_attrs({"foo": "bar"})
+
+ @T.prim_func(private=True)
+ def fused_add_exp_squeeze(
+ x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ p0: T.Buffer((), "float32"),
+ p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+ for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1])
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1]
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+ cls.fused_add_exp_squeeze,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv1)
+ return gv1
+
+ _check(Module, Expected)
+
+
+def test_use_as_inplace_and_dps():
+ @I.ir_module
+ class Module:
+ # we will use it both in-place and normally (DPS)
+ @T.prim_func(private=True)
+ def add(
+ A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ B: T.Buffer((), "float32"),
+ Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+ @R.function(private=True)
+ def fused_sums(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.add,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv1 = R.call_tir_inplace(
+ cls.add,
+ (x, p0, lv),
+ inplace_indices=[2],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ lv2 = R.call_tir_inplace(
+ cls.add,
+ (x, p0, lv1),
+ inplace_indices=[2],
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(lv2)
+ return lv2
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Module
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x,
p0)
+ R.output(gv1)
+ return gv1
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def fused_sums(
+ x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ p0: T.Buffer((), "float32"),
+ p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((),
dtype="float32")
+ ) -> R.Tensor((10, 20), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+ cls.fused_sums,
+ (x, p0),
+ out_sinfo=R.Tensor((10, 20), dtype="float32"),
+ )
+ R.output(gv1)
+ return gv1
+
+ _check(Module, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()