This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9ea3f887eb [Unity] Introduce call_dps_packed (#14183)
9ea3f887eb is described below
commit 9ea3f887eb7d36b6de621166e8628a445973cddd
Author: Yong Wu <[email protected]>
AuthorDate: Fri Mar 10 09:41:50 2023 -0800
[Unity] Introduce call_dps_packed (#14183)
Introduce call_dps_packed to call packed functions in destination-passing
style, reserving call_tir for TIR PrimFuncs instead.
* [Unity] Introduce call_dps_packed
* Fix lint
* Fix comments
* Remove the well_formed update; instead, enforce in InferStructInfoCallTIR that the first argument of call_tir is a GlobalVar
* Update src/relax/op/op.cc
* Update description of call_tir
* Remove unnecessary check in passes
---
include/tvm/relax/dataflow_pattern.h | 4 +
include/tvm/relax/transform.h | 2 +-
python/tvm/relax/__init__.py | 2 +-
python/tvm/relax/dpl/pattern.py | 18 +-
python/tvm/relax/expr.py | 2 +-
python/tvm/relax/op/base.py | 54 ++++-
python/tvm/relax/transform/transform.py | 2 +-
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/backend/task_extraction.cc | 5 -
src/relax/ir/dataflow_pattern.cc | 14 ++
src/relax/op/op.cc | 41 ++++
src/relax/transform/call_tir_rewrite.cc | 5 +-
src/relax/transform/fold_constant.cc | 12 +-
src/relax/transform/fuse_ops.cc | 20 +-
src/relax/transform/fuse_tir.cc | 23 +-
src/relax/transform/legalize_ops.cc | 3 +-
src/relax/transform/rewrite_dataflow_reshape.cc | 7 +-
src/relax/transform/run_codegen.cc | 9 +-
src/script/printer/relax/call.cc | 13 +-
src/script/printer/relax/utils.h | 3 +-
tests/python/relax/test_analysis.py | 12 +-
tests/python/relax/test_analysis_well_formed.py | 6 +-
tests/python/relax/test_ast_printer.py | 72 +++++-
tests/python/relax/test_binding_rewrite.py | 16 +-
tests/python/relax/test_dataflow_pattern.py | 254 +++++++++++----------
tests/python/relax/test_op_misc.py | 2 +-
tests/python/relax/test_transform.py | 46 +++-
.../relax/test_transform_attach_global_symbol.py | 4 +-
tests/python/relax/test_transform_bind_params.py | 4 +-
tests/python/relax/test_transform_codegen_pass.py | 4 +-
tests/python/relax/test_transform_fuse_ops.py | 4 +-
tests/python/relax/test_transform_fuse_tir.py | 2 +-
tests/python/relax/test_transform_normalize.py | 6 +-
tests/python/relax/test_tvmscript_ir_builder.py | 26 ++-
tests/python/relax/test_tvmscript_parser.py | 123 ++++++----
tests/python/relax/test_tvmscript_printer_relax.py | 15 +-
tests/python/relax/test_vm_build.py | 6 +-
37 files changed, 567 insertions(+), 276 deletions(-)
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index 701879745e..37640750a8 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args =
NullOpt);
/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
CallPattern IsCallTIR(const String& name, TuplePattern var_args);
+/*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args =
NullOpt);
+/*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern
(unordered=true) */
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 9838fe53b3..446b75da9f 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -83,7 +83,7 @@ TVM_DLL Pass LambdaLift();
TVM_DLL Pass ToNonDataflow();
/*!
- * \brief Perform explicit tensor allocation for call_tir.
+ * \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
* \return The Pass.
*/
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index bbd2040dd9..edbd848bd5 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -56,7 +56,7 @@ from .ty import Type, ObjectType, ShapeType, DynTensorType,
TupleType, FuncType,
from .exec_builder import ExecBuilder
# Operator
-from .op.base import call_tir
+from .op.base import call_tir, call_dps_packed
# BlockBuilder
from .block_builder import BlockBuilder
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 300b0af568..1ca41b378d 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -862,11 +862,23 @@ def is_call_tir(
return _is_call_tir(func_pattern, args)
-def is_call_tir_extern(
+def _is_call_dps_packed(
+ func_pattern: DFPattern,
+ args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+ if args is None:
+ args = wildcard()
+ elif isinstance(args, (list, tuple)):
+ args = TuplePattern(args)
+
+ return is_op("relax.call_dps_packed")(func_pattern, args)
+
+
+def is_call_dps_packed(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
- """Syntax sugar for creating a CallPattern for call_tir that calls an
extern function
+ """Syntax sugar for creating a CallPattern for call_dps_packed
Parameters
----------
@@ -881,7 +893,7 @@ def is_call_tir_extern(
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
- return _is_call_tir(func_pattern, args)
+ return _is_call_dps_packed(func_pattern, args)
def is_call_packed(
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index ab332eed61..4af08a3118 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -597,7 +597,7 @@ class Function(BaseFunc, Scriptable):
@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
- """extern function, which can represent a TIR PrimFunc or a PackedFunc."""
+ """extern function, which represents a PackedFunc."""
global_symbol: String
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 0b298679c1..aef0e731db 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -22,7 +22,7 @@ import tvm
from tvm.runtime.object import Object
from . import _ffi_api
-from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
+from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
@@ -45,18 +45,18 @@ def null_value() -> Call:
@args_converter.auto
def call_tir(
- func: Union[str, Expr],
+ gvar: GlobalVar,
args: Expr,
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] =
None,
) -> Call:
"""
- Call a destination-passing-style function and return the output.
+ Call a tir.prim_func and return the output.
Parameters
----------
- func : Union[str, Expr]
- The destination-passing-style function, can be ExternFunc or PrimFunc.
+ gvar : GlobalVar
+ The GlobalVar referring to a tir PrimFunc.
args : Expr
The input arguments.
@@ -74,9 +74,6 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
- if isinstance(func, str):
- func = ExternFunc(func)
-
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
args = RxTuple((args,))
@@ -86,7 +83,46 @@ def call_tir(
if isinstance(tir_vars, (list, tuple)):
tir_vars = ShapeExpr(tir_vars)
- return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore
+ return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore
+
+
+@args_converter.auto
+def call_dps_packed(
+ func: Union[str, Expr],
+ args: Expr,
+ out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
+) -> Call:
+ """
+ Call a destination-passing-style packed function and return the output.
+
+ Parameters
+ ----------
+ func : Union[str, Expr]
+ The destination-passing-style function, can be ExternFunc.
+
+ args : Expr
+ The input arguments.
+
+ out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
+ The structure info of the call_dps_packed output.
+ It should be a single or a list of TensorStructInfo. Each one denotes
the
+ structure info of a returned tensor.
+
+ Returns
+ -------
+ ret: Call
+ A call node for the call_dps_packed operator.
+ """
+ if isinstance(func, str):
+ func = ExternFunc(func)
+
+ if isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
+ args = RxTuple((args,))
+
+ if not isinstance(out_sinfo, list):
+ out_sinfo = [out_sinfo]
+
+ return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore
@args_converter.auto
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 48560792e1..97c8772b3b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -60,7 +60,7 @@ def LambdaLift():
def CallTIRRewrite() -> tvm.ir.transform.Pass:
- """Perform explicit tensor allocation for call_tir.
+ """Perform explicit tensor allocation for call_tir and call_dps_packed.
Returns
-------
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 466a38837f..c658b6f77d 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -48,6 +48,7 @@ from tvm.relax.op import (
builtin,
call_builtin_with_ctx,
call_tir,
+ call_dps_packed,
ceil,
clip,
collapse_sum_like,
@@ -545,6 +546,7 @@ __all__ = [
"builtin",
"call_packed",
"call_tir",
+ "call_dps_packed",
"call_builtin_with_ctx",
"ceil",
"clip",
diff --git a/src/relax/backend/task_extraction.cc
b/src/relax/backend/task_extraction.cc
index beb3950af1..5bd764c68e 100644
--- a/src/relax/backend/task_extraction.cc
+++ b/src/relax/backend/task_extraction.cc
@@ -73,11 +73,6 @@ class TaskExtractor : public ExprVisitor {
return;
}
- // Do not extract external function
- if (call->args[0].as<ExternFuncNode>()) {
- return;
- }
-
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const tir::PrimFunc& func =
Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 3768627c20..5eb1bf3ea6 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name,
Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
}
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern>
var_args) {
+ DFPattern arg_pattern;
+ if (!var_args.defined()) {
+ arg_pattern = Wildcard();
+ } else {
+ arg_pattern = var_args.value();
+ }
+
+ return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern);
+}
+
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
+ return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args);
+}
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
if (unordered)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index c78ca539bd..cf084d6d20 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -80,6 +80,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const
BlockBuilder& ctx) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exact 1 output struct info.");
}
+ CHECK(call->args[0]->IsInstance<GlobalVarNode>())
+ << "call_tir expects the first argument to be a GlobalVar referring to a
TIR PrimFunc. "
+ << "However, gets " << call->args[0];
return call->sinfo_args[0];
}
@@ -121,6 +124,44 @@ Expr MakeCallTIR(Expr func, Tuple args,
Array<TensorStructInfo> out_sinfo_list,
TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);
+// call_dps_packed
+
+StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder&
ctx) {
+ if (call->sinfo_args.size() != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "sinfo_args should have exact 1 output struct info.");
+ }
+ return call->sinfo_args[0];
+}
+
+RELAY_REGISTER_OP("relax.call_dps_packed")
+ .set_num_inputs(2)
+ .add_argument("func", "Expr", "The destination-passing-style function.")
+ .add_argument("args", "Tuple", "The input arguments.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCallDPSPacked);
+
+Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo>
out_sinfo_list) {
+ for (const TensorStructInfo& sinfo : out_sinfo_list) {
+ const auto* shape = sinfo->shape.as<ShapeExprNode>();
+ CHECK(shape != nullptr)
+ << "out_sinfo of call_dps_packed should have defined ShapeExpr as
shape. "
+ "However, one given structure info is "
+ << sinfo;
+ }
+
+ StructInfo out_sinfo{nullptr};
+ if (out_sinfo_list.size() == 1) {
+ out_sinfo = out_sinfo_list[0];
+ } else {
+ out_sinfo = TupleStructInfo({out_sinfo_list.begin(),
out_sinfo_list.end()});
+ }
+
+ static const Op& op = Op::Get("relax.call_dps_packed");
+ return Call(op, {func, args}, {}, {out_sinfo});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);
+
// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const
BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
diff --git a/src/relax/transform/call_tir_rewrite.cc
b/src/relax/transform/call_tir_rewrite.cc
index 2ea039e022..6066ed8d2a 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -33,7 +33,7 @@ namespace relax {
// ==================
// CallTIRMutator
-// Perform explicit tensor allocation for call_tir.
+// Perform explicit tensor allocation for call_tir or call_dps_packed.
// Example:
// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
// -->
@@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
call = expr.as<CallNode>();
static const Op& call_tir_op = Op::Get("relax.call_tir");
+ static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
- if (call->op == call_tir_op) {
+ if (call->op == call_tir_op || call->op == call_dps_packed_op) {
Array<Expr> outs;
if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr))
{
// single output case
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index 6b28f31889..622dd9ad09 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -87,13 +87,11 @@ class ConstantFolder : public ExprMutator {
* \return The TIR function, or nullopt if pattern match fails.
*/
Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
- if (auto* ptr = op.as<GlobalVarNode>()) {
- // NOTE: as check works for nullptr(returns null)
- Optional<BaseFunc> base_func =
-
builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
- if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
- return GetRef<tir::PrimFunc>(pfunc);
- }
+ const GlobalVar& global_var = Downcast<GlobalVar>(op);
+ // NOTE: as check works for nullptr(returns null)
+ Optional<BaseFunc> base_func =
builder_->GetContextIRModule()->functions.Get(global_var);
+ if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+ return GetRef<tir::PrimFunc>(pfunc);
}
return NullOpt;
}
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 3b6b3c17ac..6d7c278d80 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -185,19 +185,17 @@ class GraphCreator : public ExprVisitor {
// recurse into the call expression.
const auto* op = call->op.as<OpNode>();
if (op == call_tir_op_.get()) {
- // Skip ExternFunc for call_dps_packed.
- if (const auto* global_var = call->args[0].as<GlobalVarNode>()) {
- tir::PrimFunc func =
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(global_var)));
+ const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
- // Override args for call_tir
- args = Downcast<Tuple>(call->args[1])->fields;
+ // Override args for call_tir
+ args = Downcast<Tuple>(call->args[1])->fields;
- Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
- if (opt_pattern.defined()) {
- pattern =
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
- } else {
- pattern = OpPatternKind::kOpaque;
- }
+ Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+ if (opt_pattern.defined()) {
+ pattern =
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+ } else {
+ pattern = OpPatternKind::kOpaque;
}
}
// The pattern of the current binding variable node is set to the pattern
of this operator.
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 925f09d85d..e90d6e4bc1 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -302,11 +302,10 @@ class FusedTIRConstructor : public ExprVisitor {
// Step 1. Get Global var and PrimFunc
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
- Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
- ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir
in the module: "
- << gv;
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+
// Step 2. Renew all vars/buffer definitions and blocks to avoid
duplication
- tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
+ tir::PrimFunc prim_func = tir::RenewDefs(prim_func_);
// Step 3. Check functions are all schedulable funcs. i.e. the body of
func is root block
// TODO(Siyuan): support un-schedulable functions.
@@ -364,22 +363,6 @@ class FusedTIRConstructor : public ExprVisitor {
LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
}
- /********** Helper Functions **********/
-
- /*!
- * \brief Pattern match op to a TIR function and look it up.
- * \return The TIR function, or NullOpt if patter match fails.
- */
- Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
- // NOTE: as check works for nullptr(returns null)
- Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
- if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
- return GetRef<tir::PrimFunc>(pfunc);
- } else {
- return NullOpt;
- }
- }
-
/*!
* \brief Get the number of outputs for a call_tir node.
* \return The number of outputs.
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index f9a84c5361..350a40c37b 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -75,6 +75,7 @@ class LegalizeMutator : public ExprMutator {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
static const Op& call_tir_op = Op::Get("relax.call_tir");
+ static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
auto* op_node = visited_call->op.as<OpNode>();
// Not an OpNode
@@ -102,7 +103,7 @@ class LegalizeMutator : public ExprMutator {
}
// No legalization.
- if (op != call_tir_op) {
+ if (op != call_tir_op && op != call_dps_packed_op) {
LOG(WARNING) << "No legalization func for " << op->name << " is found.";
}
return visited_call;
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc
b/src/relax/transform/rewrite_dataflow_reshape.cc
index aec0911ecc..e5d654fba3 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -75,11 +75,8 @@ class DataflowReshapeRewriter : public ExprMutator {
if (call->op != call_tir_op) {
return false;
}
- const auto* gv = call->args[0].as<GlobalVarNode>();
- if (gv == nullptr) {
- return false;
- }
- const auto* func =
mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
+ const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+ const auto* func = mod_->functions.Get(global_var).as<tir::PrimFuncNode>();
ICHECK_NOTNULL(func);
return HasReshapePattern(GetRef<tir::PrimFunc>(func));
}
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
index 7deeb139d1..b5a4d7536f 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -75,17 +75,18 @@ class CodeGenRunner : ExprMutator {
if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
- auto create_call_tir = [call_node, this](Expr extern_func, StructInfo
ret_struct_info) {
+ auto create_call_dps_packed = [call_node, this](Expr extern_func,
+ StructInfo
ret_struct_info) {
Array<Expr> new_args({extern_func});
new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return
VisitExpr(arg); })));
- static const Op& call_op = Op::Get("relax.call_tir");
+ static const Op& call_op = Op::Get("relax.call_dps_packed");
return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
};
if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
- return create_call_tir(it->second.first, it->second.second);
+ return create_call_dps_packed(it->second.first, it->second.second);
} else {
// TODO(@sunggg): Is there any better way to get this func?
Function func =
Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
@@ -101,7 +102,7 @@ class CodeGenRunner : ExprMutator {
func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
builder_->UpdateFunction(gvar, func);
- return create_call_tir(new_func, func->ret_struct_info);
+ return create_call_dps_packed(new_func, func->ret_struct_info);
}
}
}
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 2feb2082c5..e99b81df8b 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -95,9 +95,11 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath&
n_p, const IRDocsifi
}
}
-Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p,
const IRDocsifier& d) {
+Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call& n, const
ObjectPath& n_p,
+ const IRDocsifier& d) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
- if (!n->op.same_as(call_tir_op)) {
+ static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+ if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) {
return NullOpt;
}
ICHECK(n->args.size() == 2 || n->args.size() == 3);
@@ -123,6 +125,9 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const
ObjectPath& n_p, cons
} else {
kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
}
+ if (n->op.same_as(call_dps_packed_op)) {
+ return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values);
+ }
// Step 4. Print n->args[2], the tir variables
if (n->args.size() == 3) {
kwargs_keys.push_back("tir_vars");
@@ -134,8 +139,8 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const
ObjectPath& n_p, cons
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::Call>( //
"", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
- // Special case: call_tir
- if (Optional<ExprDoc> doc = PrintCallTIR(n, n_p, d)) {
+ // Special case: call_tir, call_dps_packed
+ if (Optional<ExprDoc> doc = PrintCallTIRDPSPacked(n, n_p, d)) {
return doc.value();
}
ExprDoc prefix{nullptr};
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 7702f7b22d..8c4281ad78 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -82,7 +82,8 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v,
const ObjectPath&
}
if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
- if (call->op.same_as(call_tir_op)) {
+ static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+ if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op))
{
return NullOpt;
}
}
diff --git a/tests/python/relax/test_analysis.py
b/tests/python/relax/test_analysis.py
index 4a345224e5..8b26a2aa64 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -66,8 +66,10 @@ def test_chained_remove_all_unused():
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
- unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32),
dtype="float32"))
- unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32,
32), dtype="float32"))
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32,
32), dtype="float32"))
+ unused1 = R.call_dps_packed(
+ "my_dps_func", (unused0,), R.Tensor((32, 32),
dtype="float32")
+ )
R.output(lv0)
return lv0
@@ -92,8 +94,10 @@ def test_binding_block_remove_all_unused():
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
- unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32),
dtype="float32"))
- unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32,
32), dtype="float32"))
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32,
32), dtype="float32"))
+ unused1 = R.call_dps_packed(
+ "my_dps_func", (unused0,), R.Tensor((32, 32),
dtype="float32")
+ )
R.output(lv0)
z = R.call_packed("vm.builtin.copy", lv0,
sinfo_args=(R.Tensor((32, 32), "float32")))
return z
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index 7b8035b17c..49d2b76011 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -492,7 +492,7 @@ def test_sinfo_args_tir_var_used_before_define_call_tir():
# Error: Symbolic Var m1, n1 are not defined
m1 = tir.Var("m1", "int64")
n1 = tir.Var("n1", "int64")
- call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
+ call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1),
"float32"))
func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"),
call)])])
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod, check_struct_info=False)
@@ -505,12 +505,12 @@ def test_sinfo_erase_to_well_formed():
def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1",
"n1"), dtype="float32"):
m = T.int64()
n = T.int64()
- gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
return gv
"""
m1 = tir.Var("m1", "int64")
n1 = tir.Var("n1", "int64")
- call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
+ call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n),
"float32"))
blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]
seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr(
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
index e7f2feeaa0..de71e81464 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -429,10 +429,74 @@ def test_call_packed():
def test_call_tir():
# also from test_parser
+ @tvm.script.ir_module
+ class TestCallTIR:
+ @T.prim_func
+ def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16),
"int32")) -> None:
+ T.func_attr(({"global_symbol": "addone"}))
+ for i, j in T.grid(16, 16):
+ with T.block("addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.int32(1)
+
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "float32")):
+ m, n = T.var("int64"), T.var("int64")
+ gv0 = R.call_tir(addone, (x,), R.Tensor((m, n), dtype="float32"))
+ return gv0
+
+ mod = TestCallTIR
+ foo = mod["foo"]
+
+ foo_str = strip_whitespace(
+ dump_ast(
+ foo,
+ include_type_annotations=False,
+ include_struct_info_annotations=False,
+ include_call_attrs=False,
+ )
+ )
+ assert foo_str.startswith('Function(params=[Var(name_hint="x")]')
+
+ # call_tir is an op in Relax and it takes an extern func as an argument
+ assert isinstance(foo.body, rx.SeqExpr)
+ tir_call = foo.body.blocks[0].bindings[0].value
+ tir_call_text = dump_ast(
+ tir_call,
+ include_type_annotations=False,
+ include_struct_info_annotations=False,
+ include_call_attrs=False,
+ )
+ assert_fields(
+ "Call",
+ {
+ "op": 'Op(name="relax.call_tir")',
+ "args": """[
+ GlobalVar(name_hint="addone"),
+ Tuple(fields=[Var(name_hint="x")])
+ ]""",
+ "sinfo_args": """[
+ TensorStructInfo(
+ dtype=float32,
+ shape=ShapeExpr(
+ values=[
+ PrimExpr(value=`m`),
+ PrimExpr(value=`n`)
+ ]
+ )
+ )
+ ]""",
+ },
+ tir_call_text,
+ )
+ assert strip_whitespace(tir_call_text) in foo_str
+
+
+def test_call_dps_packed():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
- m, n = T.int64(), T.int64()
- gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
+ m, n = T.var("int64"), T.var("int64")
+ gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
return gv0
foo_str = strip_whitespace(
@@ -445,7 +509,7 @@ def test_call_tir():
)
assert foo_str.startswith('Function(params=[Var(name_hint="x")]')
- # call_tir is an op in Relax and it takes an extern func as an argument
+ # call_dps_packed is an op in Relax and it takes an extern func as an
argument
assert isinstance(foo.body, rx.SeqExpr)
tir_call = foo.body.blocks[0].bindings[0].value
tir_call_text = dump_ast(
@@ -457,7 +521,7 @@ def test_call_tir():
assert_fields(
"Call",
{
- "op": 'Op(name="relax.call_tir")',
+ "op": 'Op(name="relax.call_dps_packed")',
"args": """[
ExternFunc(global_symbol="test.op.identity"),
Tuple(fields=[Var(name_hint="x")])
diff --git a/tests/python/relax/test_binding_rewrite.py
b/tests/python/relax/test_binding_rewrite.py
index 1b424b9792..d0d3344eb6 100644
--- a/tests/python/relax/test_binding_rewrite.py
+++ b/tests/python/relax/test_binding_rewrite.py
@@ -228,8 +228,10 @@ def test_chained_rm_all_unused():
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
- unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32),
dtype="float32"))
- unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32,
32), dtype="float32"))
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32,
32), dtype="float32"))
+ unused1 = R.call_dps_packed(
+ "my_sigmoid", (unused0,), R.Tensor((32, 32),
dtype="float32")
+ )
R.output(lv0)
return lv0
@@ -262,19 +264,19 @@ def test_simple_replace_all_uses():
# \ /
# lv4
with R.dataflow():
- lv0: R.Tensor((32, 32), "float32") = R.call_tir(
+ lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed(
"my_relu", (x,), R.Tensor((32, 32), dtype="float32")
)
- lv1: R.Tensor((32, 32), "float32") = R.call_tir(
+ lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed(
"my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")
)
- lv2: R.Tensor((32, 32), "float32") = R.call_tir(
+ lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed(
"my_add", (x, lv0), R.Tensor((32, 32), dtype="float32")
)
- lv3: R.Tensor((32, 32), "float32") = R.call_tir(
+ lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed(
"my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32")
)
- lv4: R.Tensor((32, 32), "float32") = R.call_tir(
+ lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed(
"my_whatever", (lv2, lv3), R.Tensor((32, 32),
dtype="float32")
)
R.output(lv4)
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index ab7a5540ad..ba6ea99523 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -354,15 +354,15 @@ def test_simple_oub():
def test_counter_syntax_match():
with PatternContext() as ctx:
- n0 = is_call_tir_extern("tir_matmul")
- n1 = is_call_tir_extern("tir_impossible")
+ n0 = is_call_dps_packed("extern_matmul")
+ n1 = is_call_dps_packed("extern_impossible")
n0 >> n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
with PatternContext() as ctx:
- n0 = is_call_tir_extern("tir_matmul")
- n1 = is_call_tir_extern("tir_impossible")
+ n0 = is_call_dps_packed("extern_matmul")
+ n1 = is_call_dps_packed("extern_impossible")
n0 ^ n1
dfb = main_fn.body.blocks[0]
assert not ctx.match_dfb(dfb)
@@ -378,20 +378,20 @@ class Diamond:
# relu sigmoid
# \ /
# add
- lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32),
dtype="float32"))
- lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32,
32), dtype="float32"))
+ lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32),
dtype="float32"))
+ lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32,
32), dtype="float32"))
+ lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32,
32), dtype="float32"))
R.output(lv3)
return lv3
def test_diamond():
with PatternContext() as ctx:
- n0 = is_call_tir_extern("tir_matmul")
- n1 = is_call_tir_extern("tir_relu")
- n2 = is_call_tir_extern("tir_sigmoid")
- n3 = is_call_tir_extern("tir_add")
+ n0 = is_call_dps_packed("extern_matmul")
+ n1 = is_call_dps_packed("extern_relu")
+ n2 = is_call_dps_packed("extern_sigmoid")
+ n3 = is_call_dps_packed("extern_add")
n0 ^ n1
n0 ^ n2
@@ -399,15 +399,15 @@ def test_diamond():
n2 >> n3
dfb = Diamond["main"].body.blocks[0]
- assert ctx.match_dfb(dfb)
+ assert ctx.match_dfb(dfb)
# simplify it with fork_to
with PatternContext() as ctx:
- n1 = is_call_tir_extern("tir_relu")
- n2 = is_call_tir_extern("tir_sigmoid")
- n3 = is_call_tir_extern("tir_add")
+ n1 = is_call_dps_packed("extern_relu")
+ n2 = is_call_dps_packed("extern_sigmoid")
+ n3 = is_call_dps_packed("extern_add")
- is_call_tir_extern("tir_matmul").fork_to(n1, n2)
+ is_call_dps_packed("extern_matmul").fork_to(n1, n2)
n1 >> n3
n2 >> n3
@@ -417,10 +417,10 @@ def test_diamond():
def test_diamond_counter_oub():
with PatternContext() as ctx:
- n0 = is_call_tir_extern("tir_matmul")
- n1 = is_call_tir_extern("tir_relu")
- n2 = is_call_tir_extern("tir_sigmoid")
- n3 = is_call_tir_extern("tir_add")
+ n0 = is_call_dps_packed("extern_matmul")
+ n1 = is_call_dps_packed("extern_relu")
+ n2 = is_call_dps_packed("extern_sigmoid")
+ n3 = is_call_dps_packed("extern_add")
n0 >> n1
n0 >> n2
@@ -440,8 +440,8 @@ class SmallDiamond:
# / \
# \ /
# add
- lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
+ lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32),
dtype="float32"))
R.output(lv1)
return lv1
@@ -454,9 +454,9 @@ class SmallParallel:
# relu relu
# \ /
# add
- lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
+ lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32),
dtype="float32"))
+ lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32),
dtype="float32"))
R.output(lv2)
return lv2
@@ -468,8 +468,8 @@ def test_distiguish_diamond_and_parallel():
with PatternContext() as ctx:
# describe a diamond pattern
- fork = is_call_tir_extern("my_relu")
- join = is_call_tir_extern("my_add")
+ fork = is_call_dps_packed("my_relu")
+ join = is_call_dps_packed("my_add")
fork.only_used_by(join, index=0)
fork.only_used_by(join, index=1)
@@ -478,13 +478,13 @@ def test_distiguish_diamond_and_parallel():
with PatternContext() as ctx:
# describe a parallel pattern
- join = is_call_tir_extern("my_add")
+ join = is_call_dps_packed("my_add")
# Due to one-one mathcing:
- # is_call_tir_extern("my_relu") creates the 1st relu
- is_call_tir_extern("my_relu") >> join
- # is_call_tir_extern("my_relu")
+ # is_call_dps_packed("my_relu") creates the 1st relu
+ is_call_dps_packed("my_relu") >> join
+ # is_call_dps_packed("my_relu")
# creates the another different relu (obj address is different)
- is_call_tir_extern("my_relu") >> join
+ is_call_dps_packed("my_relu") >> join
assert ctx.match_dfb(parallel)
assert not ctx.match_dfb(diamond)
@@ -507,13 +507,13 @@ class CBRx2:
# \ /
# concat
with R.dataflow():
- lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32),
dtype="float32"))
- lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32),
dtype="float32"))
- lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32),
dtype="float32"))
- lv5 = R.call_tir("my_relu", (lv4), R.Tensor((32, 32),
dtype="float32"))
- lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64),
dtype="float32"))
+ lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32),
dtype="float32"))
+ lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32,
32), dtype="float32"))
+ lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32),
dtype="float32"))
+ lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32),
dtype="float32"))
+ lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32,
32), dtype="float32"))
+ lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32),
dtype="float32"))
+ lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64),
dtype="float32"))
R.output(lv6)
return lv6
@@ -521,9 +521,9 @@ class CBRx2:
def test_single_cbr():
with PatternContext() as ctx:
(
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("bias_add")
- >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("bias_add")
+ >> is_call_dps_packed("my_relu")
)
dfb = CBRx2["main"].body.blocks[0]
matched = ctx.match_dfb(dfb)
@@ -531,9 +531,9 @@ def test_single_cbr():
with PatternContext() as ctx:
chain = (
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("bias_add")
- >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("bias_add")
+ >> is_call_dps_packed("my_relu")
)
dfb = CBRx2["main"].body.blocks[0]
# we want to specifically match the first CBR (lv0)
@@ -549,9 +549,9 @@ def test_single_cbr():
def test_counter_single_crb():
with PatternContext() as ctx:
(
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("my_relu")
- >> is_call_tir_extern("bias_add")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("my_relu")
+ >> is_call_dps_packed("bias_add")
)
dfb = CBRx2["main"].body.blocks[0]
assert not ctx.match_dfb(dfb)
@@ -567,14 +567,14 @@ def test_nested_context():
dfb = CBRx2["main"].body.blocks[0]
with PatternContext() as ctx0:
(
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("bias_add")
- >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("bias_add")
+ >> is_call_dps_packed("my_relu")
)
with PatternContext() as ctx1:
- is_call_tir_extern("conv1x1") >> is_call_tir_extern("my_relu") #
pattern to miss
+ is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") #
pattern to miss
with PatternContext() as ctx2:
- is_call_tir_extern("bias_add") >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
assert ctx2.match_dfb(dfb)
assert PatternContext.current() == ctx2
assert not ctx1.match_dfb(dfb)
@@ -586,9 +586,9 @@ def test_nested_context():
def test_two_cbr():
with PatternContext() as ctx:
cbr0 = (
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("bias_add")
- >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("bias_add")
+ >> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
@@ -603,9 +603,9 @@ def test_two_cbr():
with PatternContext() as ctx:
# Deny the pattern
cbr0 = (
- is_call_tir_extern("conv1x1")
- >> is_call_tir_extern("bias_add")
- >> is_call_tir_extern("my_relu")
+ is_call_dps_packed("conv1x1")
+ >> is_call_dps_packed("bias_add")
+ >> is_call_dps_packed("my_relu")
)
cbr1 = cbr0.dup()
@@ -626,25 +626,25 @@ def test_two_matmul():
c: R.Tensor((48, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
- lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48),
dtype="float32"))
- lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48),
dtype="float32"))
+ lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32),
dtype="float32"))
R.output(lv1)
return lv1
with PatternContext() as ctx:
- is_call_tir_extern("matmul") >> is_call_tir_extern("matmul")
+ is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
- is_call_tir_extern("matmul").has_shape([32, 48]) >>
is_call_tir_extern("matmul").has_shape(
+ is_call_dps_packed("matmul").has_shape([32, 48]) >>
is_call_dps_packed("matmul").has_shape(
[32, 32]
)
dfb = MatMul2["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
- is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") >>
is_call_tir_extern("matmul")
+ is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >>
is_call_dps_packed("matmul")
dfb = MatMul2["main"].body.blocks[0]
# Three MatMul cannot match
assert not ctx.match_dfb(dfb)
@@ -661,9 +661,9 @@ def test_concat_mm_split():
c: R.Tensor((16, 32), "float32"),
) -> R.Tensor:
with R.dataflow():
- lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir(
+ lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32,
32), dtype="float32"))
+ lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32,
32), dtype="float32"))
+ lv2 = R.call_dps_packed(
"my_split",
(lv1,),
[R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32),
dtype="float32")],
@@ -676,15 +676,15 @@ def test_concat_mm_split():
with PatternContext() as ctx:
(
- is_call_tir_extern("my_concat")
- >> is_call_tir_extern("my_matmul")
- >> is_call_tir_extern("my_split")
+ is_call_dps_packed("my_concat")
+ >> is_call_dps_packed("my_matmul")
+ >> is_call_dps_packed("my_split")
)
dfb = CMS["main"].body.blocks[0]
assert ctx.match_dfb(dfb)
with PatternContext() as ctx:
- split = is_call_tir_extern("my_split")
+ split = is_call_dps_packed("my_split")
lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32])
lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32])
split.fork_to(lv3, lv4)
@@ -711,18 +711,26 @@ def test_self_attention():
) -> R.Tensor:
b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
with R.dataflow():
- fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h),
dtype="float32"))
- tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h,
n), dtype="float32"))
+ fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n,
h), dtype="float32"))
+ tpq = R.call_dps_packed(
+ "my_transpose", (fcq,), R.Tensor((b, s, h, n),
dtype="float32")
+ )
- fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h),
dtype="float32"))
- tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h,
n), dtype="float32"))
+ fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n,
h), dtype="float32"))
+ tpk = R.call_dps_packed(
+ "my_transpose", (fck,), R.Tensor((b, s, h, n),
dtype="float32")
+ )
mul = R.multiply(tpq, tpk)
scale = R.multiply(mul, R.const(1.1, "float32"))
- softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n,
h), dtype="float32"))
+ softmax = R.call_dps_packed(
+ "softmax", (scale,), R.Tensor((b, s, n, h),
dtype="float32")
+ )
- fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h),
dtype="float32"))
- tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h,
n), dtype="float32"))
+ fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n,
h), dtype="float32"))
+ tpv = R.call_dps_packed(
+ "my_transpose", (fcv,), R.Tensor((b, s, h, n),
dtype="float32")
+ )
out = R.multiply(softmax, tpv)
R.output(out)
@@ -730,7 +738,7 @@ def test_self_attention():
return out
with PatternContext() as ctx:
- fc_trans_q = is_call_tir_extern("my_fc") >>
is_call_tir_extern("my_transpose")
+ fc_trans_q = is_call_dps_packed("my_fc") >>
is_call_dps_packed("my_transpose")
fc_trans_k = fc_trans_q.dup()
fc_trans_v = fc_trans_q.dup()
@@ -752,43 +760,59 @@ def test_nested_diamond():
# add5 add6
# \ /
# add7
- lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32),
dtype="float32"))
- lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32),
dtype="float32"))
- lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32),
dtype="float32"))
- lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32),
dtype="float32"))
- lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32),
dtype="float32"))
- lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed(
+ "extern_matmul", (x, w), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv1 = R.call_dps_packed(
+ "extern_matmul", (x, w), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv2 = R.call_dps_packed(
+ "extern_sigmoid", (lv0), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv3 = R.call_dps_packed(
+ "extern_sigmoid", (lv1), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv4 = R.call_dps_packed(
+ "extern_add", (lv0, lv1), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv5 = R.call_dps_packed(
+ "extern_add", (lv2, lv4), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv6 = R.call_dps_packed(
+ "extern_add", (lv3, lv4), R.Tensor((32, 32),
dtype="float32")
+ )
+ lv7 = R.call_dps_packed(
+ "extern_add", (lv5, lv6), R.Tensor((32, 32),
dtype="float32")
+ )
R.output(lv7)
return lv7
# match matmul0 diamond
with PatternContext() as ctx:
- sigmoid2 = is_call_tir_extern("tir_sigmoid")
- add4 = is_call_tir_extern("tir_add")
- is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4)
- add5 = is_call_tir_extern("tir_add")
+ sigmoid2 = is_call_dps_packed("extern_sigmoid")
+ add4 = is_call_dps_packed("extern_add")
+ is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
+ add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 ^ add5
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# counter case: mis-match matmul0 diamond
with PatternContext() as ctx:
- sigmoid2 = is_call_tir_extern("tir_sigmoid")
- add4 = is_call_tir_extern("tir_add")
- is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4)
- add5 = is_call_tir_extern("tir_add")
+ sigmoid2 = is_call_dps_packed("extern_sigmoid")
+ add4 = is_call_dps_packed("extern_add")
+ is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
+ add5 = is_call_dps_packed("extern_add")
sigmoid2 >> add5
add4 >> add5 # not only-used-by relation
assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
# match matmul1 diamond
with PatternContext() as ctx:
- sigmoid3 = is_call_tir_extern("tir_sigmoid")
- add4 = is_call_tir_extern("tir_add")
- is_call_tir_extern("tir_matmul").fork_to(sigmoid3, add4)
- add6 = is_call_tir_extern("tir_add")
+ sigmoid3 = is_call_dps_packed("extern_sigmoid")
+ add4 = is_call_dps_packed("extern_add")
+ is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4)
+ add6 = is_call_dps_packed("extern_add")
sigmoid3 >> add6
add4 ^ add6
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
@@ -796,11 +820,11 @@ def test_nested_diamond():
# match add-4-5-6-7
with PatternContext() as ctx:
add5, add6, add7 = (
- is_call_tir_extern("tir_add"),
- is_call_tir_extern("tir_add"),
- is_call_tir_extern("tir_add"),
+ is_call_dps_packed("extern_add"),
+ is_call_dps_packed("extern_add"),
+ is_call_dps_packed("extern_add"),
)
- is_call_tir_extern("tir_add").fork_to(add5, add6) # add4
+ is_call_dps_packed("extern_add").fork_to(add5, add6) # add4
add5 >> add7
add6 >> add7
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
@@ -811,15 +835,15 @@ def test_incremental_solving():
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# relu -> sigmoid -> neg
- lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32),
dtype="float32"))
- lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32),
dtype="float32"))
+ lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32,
32), dtype="float32"))
+ lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32),
dtype="float32"))
R.output(lv2)
return lv2
- relu = is_call_tir_extern("tir_relu")
- sigmoid = is_call_tir_extern("tir_sigmoid")
- neg = is_call_tir_extern("tir_neg")
+ relu = is_call_dps_packed("extern_relu")
+ sigmoid = is_call_dps_packed("extern_sigmoid")
+ neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid
@@ -840,14 +864,14 @@ def test_incremental_solving_counter():
def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
# sigmoid -> neg
- lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32),
dtype="float32"))
- lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32),
dtype="float32"))
+ lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32),
dtype="float32"))
R.output(lv1)
return lv1
- relu = is_call_tir_extern("tir_relu")
- sigmoid = is_call_tir_extern("tir_sigmoid")
- neg = is_call_tir_extern("tir_neg")
+ relu = is_call_dps_packed("extern_relu")
+ sigmoid = is_call_dps_packed("extern_sigmoid")
+ neg = is_call_dps_packed("extern_neg")
with PatternContext() as ctx0:
relu >> sigmoid # cannot match
diff --git a/tests/python/relax/test_op_misc.py
b/tests/python/relax/test_op_misc.py
index 523a628fa9..fd23911533 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -39,7 +39,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None:
def test_call_tir() -> None:
v0 = rx.Var("v0", R.Tensor([54, 96], "float32"))
- v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96),
"float32"))
+ v1 = rx.call_dps_packed(rx.extern("test.op.identity"), [v0], R.Tensor((54,
96), "float32"))
v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32"))
diff --git a/tests/python/relax/test_transform.py
b/tests/python/relax/test_transform.py
index 85de4f912e..3e6305c492 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -32,8 +32,10 @@ def test_to_non_dataflow():
def foo(x: R.Tensor(("m", "n"), "float32")):
m, n = T.int64(), T.int64()
with R.dataflow():
- lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
- gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n),
dtype="float32"))
+ lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m,
n), dtype="float32"))
+ gv0 = R.call_dps_packed(
+ "test.op.identity", (lv0,), R.Tensor((m, n),
dtype="float32")
+ )
R.output(gv0)
return gv0
@@ -73,10 +75,14 @@ def test_to_non_dataflow():
def test_call_tir_rewrite():
@tvm.script.ir_module
class TestCallTIRRewrite:
+ @T.prim_func
+ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3),
"float32")):
+ T.evaluate(0)
+
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
m, n = T.int64(), T.int64()
- gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
+ gv0 = R.call_tir(exp, (x,), R.Tensor((m, n), dtype="float32"))
return gv0
mod = TestCallTIRRewrite
@@ -94,6 +100,40 @@ def test_call_tir_rewrite():
block = func.body.blocks[0]
assert not isinstance(block, relax.DataflowBlock)
+ s1 = block.bindings[0].value
+ assert isinstance(s1, relax.Call)
+ assert s1.op.name == "relax.builtin.alloc_tensor"
+ assert isinstance(s1.args[0], relax.ShapeExpr)
+ assert structural_equal(s1.args[0], s0.sinfo_args[0].shape)
+ s2 = block.bindings[1].value
+ tvm.ir.expr.GlobalVar
+ assert s2.op.name_hint == "exp"
+
+
+def test_call_dps_packed_rewrite():
+ @tvm.script.ir_module
+ class TestCallDPSPackedRewrite:
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), "float32")):
+ m, n = T.int64(), T.int64()
+ gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
+ return gv0
+
+ mod = TestCallDPSPackedRewrite
+
+ # before rewrite
+ v0 = mod["foo"].body.blocks[0].bindings[0].var
+ s0 = mod["foo"].body.blocks[0].bindings[0].value
+ assert isinstance(s0, relax.Call)
+ assert s0.op.name == "relax.call_dps_packed"
+
+ # CallTIRRewrite also works for call_dps_packed
+ new_mod = relax.transform.CallTIRRewrite()(mod)
+ func = new_mod["foo"]
+
+ block = func.body.blocks[0]
+ assert not isinstance(block, relax.DataflowBlock)
+
s1 = block.bindings[0].value
assert isinstance(s1, relax.Call)
assert s1.op.name == "relax.builtin.alloc_tensor"
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py
b/tests/python/relax/test_transform_attach_global_symbol.py
index cef3842e3e..7fc6798e37 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -45,7 +45,7 @@ class Before:
@R.function
def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")) -> R.Tensor:
m, n, k = T.int64(), T.int64(), T.int64()
- gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k),
dtype="float32"))
+ gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32"))
return gv0
@@ -75,7 +75,7 @@ def test_basic():
) -> R.Tensor:
R.func_attr({"global_symbol": "main"})
m, n, k = T.int64(), T.int64(), T.int64()
- gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k),
dtype="float32"))
+ gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k),
dtype="float32"))
return gv0
before = Before
diff --git a/tests/python/relax/test_transform_bind_params.py
b/tests/python/relax/test_transform_bind_params.py
index 1dfd9e0c8e..2a30586b1b 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -87,10 +87,10 @@ def test_bind_params_symbolic_vars():
m = T.Var("m", "int64")
n = T.Var("n", "int64")
with R.dataflow():
- lv0 = R.call_tir(
+ lv0 = R.call_dps_packed(
"linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n),
dtype="float32")
)
- out = R.call_tir(
+ out = R.call_dps_packed(
"linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k),
dtype="float32")
)
R.output(out)
diff --git a/tests/python/relax/test_transform_codegen_pass.py
b/tests/python/relax/test_transform_codegen_pass.py
index 3e9501147a..d82706200a 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -235,12 +235,12 @@ class Conv2dx2_after:
weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
with R.dataflow():
- lv = R.call_tir(
+ lv = R.call_dps_packed(
"fused_relax_nn_conv2d_tensorrt",
(data, weight1),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
)
- gv = R.call_tir(
+ gv = R.call_dps_packed(
"fused_relax_nn_conv2d_tensorrt",
(lv, weight2),
out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index 33d57417cf..14d70ab77c 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -826,7 +826,7 @@ def test_skip_call_dps_packed():
@R.function
def main(x: R.Tensor((2, 3), "float32")):
with R.dataflow():
- y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3),
"float32"))
+ y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3),
"float32"))
R.output(y)
return y
@@ -842,7 +842,7 @@ def test_edge_with_call_dps_packed():
with R.dataflow():
a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3),
"float32"))
b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3),
"float32"))
- c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3),
"float32"))
+ c = R.call_dps_packed("packed_dps", (a,),
out_sinfo=R.Tensor((2, 3), "float32"))
R.output(b, c)
return R.tuple(b, c)
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index c2784edec7..f8d488e43b 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -690,7 +690,7 @@ def test_skip_call_dps_packed():
@R.function
def main(x: R.Tensor((2, 3), "float32")):
with R.dataflow():
- y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3),
"float32"))
+ y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3),
"float32"))
R.output(y)
return y
diff --git a/tests/python/relax/test_transform_normalize.py
b/tests/python/relax/test_transform_normalize.py
index da123f956d..874e83c7f9 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -124,8 +124,10 @@ def test_normalize_no_op():
def foo(x: R.Tensor(("m", "n"), "float32")):
m, n = T.int64(), T.int64()
with R.dataflow():
- lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n),
dtype="float32"))
- gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n),
dtype="float32"))
+ lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m,
n), dtype="float32"))
+ gv0 = R.call_dps_packed(
+ "test.op.identity", (lv0,), R.Tensor((m, n),
dtype="float32")
+ )
R.output(gv0)
return gv0
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py
b/tests/python/relax/test_tvmscript_ir_builder.py
index f7c29b8dbe..e103e9cddd 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -25,7 +25,7 @@ def test_function_simple():
"""
@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32",
ndim=2):
- out = R.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ out = R.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
return out
"""
# create with Script IRBuilder
@@ -35,8 +35,15 @@ def test_function_simple():
R.func_attr({"Primitive": 1})
x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32",
ndim=2))
+ y = R.emit(
+ R.call_dps_packed(
+ "extern_func", x, relax.TensorStructInfo((128, 128),
dtype="float32")
+ )
+ )
out = R.emit(
- R.call_tir("extern_func", x, relax.TensorStructInfo((128,
128), dtype="float32"))
+ R.call_dps_packed(
+ "extern_dps_func", y, relax.TensorStructInfo((128, 128),
dtype="float32")
+ )
)
IRBuilder.name("out", out)
R.func_ret_value(out)
@@ -45,8 +52,15 @@ def test_function_simple():
x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x,), attrs={"Primitive": 1}):
+ y = bb.emit(
+ relax.call_dps_packed(
+ "extern_func", x, relax.TensorStructInfo((128, 128),
dtype="float32")
+ )
+ )
out = bb.emit(
- relax.call_tir("extern_func", x, relax.TensorStructInfo((128,
128), dtype="float32"))
+ relax.call_dps_packed(
+ "extern_dps_func", y, relax.TensorStructInfo((128, 128),
dtype="float32")
+ )
)
bb.emit_func_output(out)
mod = bb.get()
@@ -112,7 +126,7 @@ def test_dataflow_block():
def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim
= 2):
# block 0
with R.dataflow():
- lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_func", (x,), R.Tensor((128, 128),
dtype="float32"))
gv: Tensor((128, 128), "float32") = lv0
R.output(gv)
return gv
@@ -124,7 +138,7 @@ def test_dataflow_block():
x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
with R.dataflow() as df:
lv0 = R.emit(
- R.call_tir(
+ R.call_dps_packed(
"extern_func", x, relax.TensorStructInfo((128, 128),
dtype="float32")
)
)
@@ -142,7 +156,7 @@ def test_dataflow_block():
with bb.function("foo", (x,)):
with bb.dataflow():
lv0 = bb.emit(
- relax.call_tir(
+ relax.call_dps_packed(
"extern_func", x, relax.TensorStructInfo((128, 128),
dtype="float32")
)
)
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index ffc024403d..136bd8c1ea 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -43,13 +43,17 @@ def test_simple_func():
@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128),
"float32"):
R.func_attr({"Primitive": 1})
- gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
- return gv0
+ gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128),
dtype="float32"))
+ return gv1
x = relax.Var("x", R.Tensor((128, 128), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x,), attrs={"Primitive": 1}):
- out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32")))
+ y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128,
128), dtype="float32")))
+ out = bb.emit(
+ relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128),
dtype="float32"))
+ )
bb.emit_func_output(out)
_check(foo, bb.get()["foo"])
@@ -111,15 +115,34 @@ def test_unexpected_tir_cast_args():
return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),),
dtype="float32"))
-def test_unexpected_tir_max_args():
+def test_unexpected_tir_args():
+
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @tvm.script.ir_module
+ class TestWellCallTIR:
+ @T.prim_func
+ def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16,
16), "int32")) -> None:
+ T.func_attr(({"global_symbol": "tir_addone"}))
+ for i, j in T.grid(16, 16):
+ with T.block("tir_addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.int32(1)
+
+ @R.function
+ def foo(x: R.Tensor(("m", "m"), "float32")):
+ m = T.int64()
+ # tir.max expects 2 arguments, but got 1
+ gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),),
dtype="float32"))
+ return gv
with pytest.raises(tvm.error.DiagnosticError):
@R.function
def f(x: R.Tensor(("m", "n"), "float32")):
m = T.int64()
- # tir.max expects 2 arguments, but got 1
- return relax.call_tir("foo", (x,), R.Tensor((T.max(m),),
dtype="float32"))
+ # call_tir expected a tir prim_func
+ return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),),
dtype="float32"))
def test_func_type_annotation_fail():
@@ -315,14 +338,14 @@ def test_symbolic_shape():
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
m = T.int64()
n = T.int64()
- gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+ gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n),
dtype="float32"))
return gv0
@R.function
def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
m = T.int64()
n = T.int64()
- gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+ gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n),
dtype="float32"))
return gv0
with pytest.raises(tvm.error.DiagnosticError):
@@ -331,7 +354,7 @@ def test_symbolic_shape():
def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) ->
R.Tensor(None, "float32", ndim=2):
m = T.int64()
n = T.int32() # The shape dtype should be int64
- gv0 = R.call_tir("extern_func", x, R.Tensor((m, n),
dtype="float32"))
+ gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n),
dtype="float32"))
return gv0
def _expected(name: str):
@@ -339,7 +362,9 @@ def test_symbolic_shape():
x = relax.Var("x", R.Tensor([m, n], "float32"))
bb = relax.BlockBuilder()
with bb.function(name, (x,)):
- out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n),
dtype="float32")))
+ out = bb.emit(
+ relax.call_dps_packed("extern_func", x, R.Tensor((m, n),
dtype="float32"))
+ )
bb.emit_func_output(out)
return bb.get()[name]
@@ -403,15 +428,15 @@ def test_match_cast():
def test_tuple_return():
@R.function
def foo(x: R.Tensor((4, 4), "float32")):
- gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))
- gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))
+ gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4),
dtype="float32"))
+ gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4),
dtype="float32"))
return (gv0, gv1)
x = relax.Var("x", R.Tensor((4, 4), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
- gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4),
dtype="float32")))
- gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4),
dtype="float32")))
+ gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4,
4), dtype="float32")))
+ gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4,
4), dtype="float32")))
bb.emit_func_output(relax.Tuple((gv0, gv1)))
_check(foo, bb.get()["foo"])
@@ -483,8 +508,8 @@ def test_dataflow_block():
@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32",
ndim=2):
with R.dataflow():
- lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
- lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
gv = lv1
R.output(gv)
return gv
@@ -493,8 +518,12 @@ def test_dataflow_block():
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
with bb.dataflow():
- lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128,
128), dtype="float32")))
- lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128,
128), dtype="float32")))
+ lv0 = bb.emit(
+ relax.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ )
+ lv1 = bb.emit(
+ relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
+ )
gv = bb.emit_output(lv1)
bb.emit_func_output(gv)
@@ -504,22 +533,22 @@ def test_dataflow_block():
def test_dataflow_block_advanced():
@R.function
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32",
ndim=2):
- gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
- gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128),
dtype="float32"))
+ gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128),
dtype="float32"))
with R.dataflow():
m = T.int64()
n = T.int64()
- lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128),
dtype="float32"))
+ lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128),
dtype="float32"))
lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
- gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
- gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128),
dtype="float32"))
+ gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
+ gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128),
dtype="float32"))
gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32"))
gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
gv4 = gv3
gv5 = gv2
R.output(gv5, gv4)
- gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128),
dtype="float32"))
- gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128),
dtype="float32"))
+ gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128),
dtype="float32"))
+ gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128),
dtype="float32"))
return gv7
x = relax.Var("x", R.Tensor((128, 128), "float32"))
@@ -527,21 +556,33 @@ def test_dataflow_block_advanced():
m = tir.Var("m", dtype="int64")
n = tir.Var("n", dtype="int64")
with bb.function("foo", (x,)):
- gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128),
dtype="float32")))
- gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128),
dtype="float32")))
+ gv0 = bb.emit(
+ relax.call_dps_packed("extern_func", x, R.Tensor((128, 128),
dtype="float32"))
+ )
+ gv1 = bb.emit(
+ relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128),
dtype="float32"))
+ )
with bb.dataflow():
- lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128,
128), dtype="float32")))
+ lv0 = bb.emit(
+ relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128),
dtype="float32"))
+ )
lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
- gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128,
128), dtype="float32")))
+ gv2 = bb.emit(
+ relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128),
dtype="float32"))
+ )
gv21 = bb.emit(
- relax.call_tir("extern_func", gv2, R.Tensor((128, 128),
dtype="float32"))
+ relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128),
dtype="float32"))
)
gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32"))
gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
gv32 = bb.emit_output(gv31)
gv22 = bb.emit_output(gv21)
- gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128),
dtype="float32")))
- gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128),
dtype="float32")))
+ gv4 = bb.emit(
+ relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128),
dtype="float32"))
+ )
+ gv5 = bb.emit(
+ relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128),
dtype="float32"))
+ )
bb.emit_func_output(gv5)
_check(foo, bb.get()["foo"])
@@ -640,13 +681,13 @@ def test_function_without_return():
def test_tensor_type_without_args():
@R.function
def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
- v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))
+ v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32),
dtype="float32"))
return v
x = relax.Var("x", R.Tensor((32, 32), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x)):
- v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32),
dtype="float32")))
+ v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32),
dtype="float32")))
bb.emit_func_output(v)
_check(foo, bb.get()["foo"])
@@ -753,10 +794,10 @@ def test_annotate_override():
assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo)
-def test_call_tir_empty_shape():
+def test_call_dps_packed_empty_shape():
@R.function
def foo(x: R.Tensor((), "float32")):
- z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32"))
+ z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32"))
return z
(z_bind,) = foo.body.blocks[0].bindings
@@ -1024,7 +1065,7 @@ def test_symbolic_shape_computing():
x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",),
"float32")
) -> R.Tensor(("T.max(m, 20) + 1",), "float32"):
m = T.int64()
- z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,),
dtype="float32"))
+ z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) +
1,), dtype="float32"))
return z
m = tir.Var("m", "int64")
@@ -1033,7 +1074,9 @@ def test_symbolic_shape_computing():
bb = relax.BlockBuilder()
with bb.function("bar", (x, y)):
z = bb.emit(
- relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) +
1,), dtype="float32"))
+ relax.call_dps_packed(
+ "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,),
dtype="float32")
+ )
)
bb.emit_func_output(z)
@@ -1043,7 +1086,7 @@ def test_symbolic_shape_computing():
@R.function
def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
m = T.int64()
- z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
+ z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,),
dtype="float32"))
return z
m = tir.Var("m", "int64")
@@ -1051,7 +1094,7 @@ def test_symbolic_shape_computing():
y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
bb = relax.BlockBuilder()
with bb.function("baz", (x, y)):
- z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,),
dtype="float32")))
+ z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m *
2,), dtype="float32")))
bb.emit_func_output(z)
_check(baz, bb.get()["baz"])
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py
b/tests/python/relax/test_tvmscript_printer_relax.py
index 464591f259..76bb3bb812 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -299,13 +299,22 @@ def test_shape_expr():
def test_call():
x = tir.Var("x", "int64")
a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
- obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info,
tir_vars=[x])
+ o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a,
out_sinfo=a.struct_info, tir_vars=[x])
+ o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info)
_assert_print(
- obj,
+ o0,
+ """
+x = T.int64()
+a: R.Tensor((1, x, 3), dtype="float32")
+R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"),
tir_vars=R.shape([x]))
+""",
+ )
+ _assert_print(
+ o1,
"""
x = T.int64()
a: R.Tensor((1, x, 3), dtype="float32")
-R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"),
tir_vars=R.shape([x]))
+R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3),
dtype="float32"))
""",
)
diff --git a/tests/python/relax/test_vm_build.py
b/tests/python/relax/test_vm_build.py
index e51e22e323..776679103f 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -121,7 +121,7 @@ def test_vm_compile_stage3(exec_mode):
@R.function
def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
with R.dataflow():
- y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16),
dtype="float32"))
+ y = R.call_dps_packed("test.vm.identity", (x), R.Tensor((32,
16), dtype="float32"))
R.output(y)
return y
@@ -145,7 +145,7 @@ def test_vm_compile_e2e(exec_mode):
with R.dataflow():
n, m = T.int64(), T.int64()
_ = R.match_cast(x, R.Tensor((n, m), "float32"))
- y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2),
dtype="float32"))
+ y = R.call_dps_packed("test.vm.tile", (x), R.Tensor((n, m *
2), dtype="float32"))
R.output(y)
return y
@@ -714,7 +714,7 @@ class TestVMSetInput:
@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32),
"float32")) -> R.Tensor:
- gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32),
dtype="float32"))
+ gv0 = R.call_tir(test_vm_mul, (x, w), R.Tensor((32, 32),
dtype="float32"))
return gv0