This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9ea3f887eb [Unity] Introduce call_dps_packed (#14183)
9ea3f887eb is described below

commit 9ea3f887eb7d36b6de621166e8628a445973cddd
Author: Yong Wu <[email protected]>
AuthorDate: Fri Mar 10 09:41:50 2023 -0800

    [Unity] Introduce call_dps_packed (#14183)
    
    Introduce call_dps_packed to call packed functions in destination-passing style, reserving call_tir for TIR PrimFuncs instead.
    
    * [Unity] Introduce call_dps_packed
    
    * fix lint
    
    * Fix comments
    
    * Remove well_form update, enforce in InferStructInfoCallTIR
    
    * Update src/relax/op/op.cc
    
    * Update description of call_tir
    
    * Remove unnecessary check in passes
---
 include/tvm/relax/dataflow_pattern.h               |   4 +
 include/tvm/relax/transform.h                      |   2 +-
 python/tvm/relax/__init__.py                       |   2 +-
 python/tvm/relax/dpl/pattern.py                    |  18 +-
 python/tvm/relax/expr.py                           |   2 +-
 python/tvm/relax/op/base.py                        |  54 ++++-
 python/tvm/relax/transform/transform.py            |   2 +-
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/backend/task_extraction.cc               |   5 -
 src/relax/ir/dataflow_pattern.cc                   |  14 ++
 src/relax/op/op.cc                                 |  41 ++++
 src/relax/transform/call_tir_rewrite.cc            |   5 +-
 src/relax/transform/fold_constant.cc               |  12 +-
 src/relax/transform/fuse_ops.cc                    |  20 +-
 src/relax/transform/fuse_tir.cc                    |  23 +-
 src/relax/transform/legalize_ops.cc                |   3 +-
 src/relax/transform/rewrite_dataflow_reshape.cc    |   7 +-
 src/relax/transform/run_codegen.cc                 |   9 +-
 src/script/printer/relax/call.cc                   |  13 +-
 src/script/printer/relax/utils.h                   |   3 +-
 tests/python/relax/test_analysis.py                |  12 +-
 tests/python/relax/test_analysis_well_formed.py    |   6 +-
 tests/python/relax/test_ast_printer.py             |  72 +++++-
 tests/python/relax/test_binding_rewrite.py         |  16 +-
 tests/python/relax/test_dataflow_pattern.py        | 254 +++++++++++----------
 tests/python/relax/test_op_misc.py                 |   2 +-
 tests/python/relax/test_transform.py               |  46 +++-
 .../relax/test_transform_attach_global_symbol.py   |   4 +-
 tests/python/relax/test_transform_bind_params.py   |   4 +-
 tests/python/relax/test_transform_codegen_pass.py  |   4 +-
 tests/python/relax/test_transform_fuse_ops.py      |   4 +-
 tests/python/relax/test_transform_fuse_tir.py      |   2 +-
 tests/python/relax/test_transform_normalize.py     |   6 +-
 tests/python/relax/test_tvmscript_ir_builder.py    |  26 ++-
 tests/python/relax/test_tvmscript_parser.py        | 123 ++++++----
 tests/python/relax/test_tvmscript_printer_relax.py |  15 +-
 tests/python/relax/test_vm_build.py                |   6 +-
 37 files changed, 567 insertions(+), 276 deletions(-)

diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 701879745e..37640750a8 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
 CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
 /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
 CallPattern IsCallTIR(const String& name, TuplePattern var_args);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tensor) */
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tuple of tensor) */
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
 /*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
 /*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 9838fe53b3..446b75da9f 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -83,7 +83,7 @@ TVM_DLL Pass LambdaLift();
 TVM_DLL Pass ToNonDataflow();
 
 /*!
- * \brief Perform explicit tensor allocation for call_tir.
+ * \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
  *
  * \return The Pass.
  */
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index bbd2040dd9..edbd848bd5 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -56,7 +56,7 @@ from .ty import Type, ObjectType, ShapeType, DynTensorType, TupleType, FuncType,
 from .exec_builder import ExecBuilder

 # Operator
-from .op.base import call_tir
+from .op.base import call_tir, call_dps_packed

 # BlockBuilder
 from .block_builder import BlockBuilder
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 300b0af568..1ca41b378d 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -862,11 +862,23 @@ def is_call_tir(
     return _is_call_tir(func_pattern, args)


-def is_call_tir_extern(
+def _is_call_dps_packed(
+    func_pattern: DFPattern,
+    args: Union[List, Tuple, TuplePattern] = None,
+) -> CallPattern:
+    if args is None:
+        args = wildcard()
+    elif isinstance(args, (list, tuple)):
+        args = TuplePattern(args)
+
+    return is_op("relax.call_dps_packed")(func_pattern, args)
+
+
+def is_call_dps_packed(
     func_name: str,
     args: Union[List, Tuple, TuplePattern] = None,
 ) -> CallPattern:
-    """Syntax sugar for creating a CallPattern for call_tir that calls an extern function
+    """Syntax sugar for creating a CallPattern for call_dps_packed

     Parameters
     ----------
@@ -881,7 +893,7 @@ def is_call_tir_extern(
         The resulting CallPattern
     """
     func_pattern = ExternFuncPattern(func_name)
-    return _is_call_tir(func_pattern, args)
+    return _is_call_dps_packed(func_pattern, args)


 def is_call_packed(
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index ab332eed61..4af08a3118 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -597,7 +597,7 @@ class Function(BaseFunc, Scriptable):
 
 @tvm._ffi.register_object("relax.expr.ExternFunc")
 class ExternFunc(BaseFunc):
-    """extern function, which can represent a TIR PrimFunc or a PackedFunc."""
+    """extern function, which represents a PackedFunc."""
 
     global_symbol: String
 
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 0b298679c1..aef0e731db 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -22,7 +22,7 @@ import tvm
 from tvm.runtime.object import Object

 from . import _ffi_api
-from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
+from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
 from ..expr import Tuple as RxTuple
 from ..struct_info import StructInfo, TensorStructInfo
 from ...ir import PrimExpr
@@ -45,18 +45,18 @@ def null_value() -> Call:

 @args_converter.auto
 def call_tir(
-    func: Union[str, Expr],
+    gvar: GlobalVar,
     args: Expr,
     out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
     tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
 ) -> Call:
     """
-    Call a destination-passing-style function and return the output.
+    Call a tir.prim_func and return the output.

     Parameters
     ----------
-    func : Union[str, Expr]
-        The destination-passing-style function, can be ExternFunc or PrimFunc.
+    gvar : GlobalVar
+        The GlobalVar referring to a tir PrimFunc.

     args : Expr
         The input arguments.
@@ -74,9 +74,6 @@ def call_tir(
     ret: Call
         A call node for the call_tir operator.
     """
-    if isinstance(func, str):
-        func = ExternFunc(func)
-
     if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
         args = RxTuple((args,))

@@ -86,7 +83,46 @@ def call_tir(
     if isinstance(tir_vars, (list, tuple)):
         tir_vars = ShapeExpr(tir_vars)

-    return _ffi_api.call_tir(func, args, out_sinfo, tir_vars)  # type: ignore
+    return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars)  # type: ignore
+
+
+@args_converter.auto
+def call_dps_packed(
+    func: Union[str, Expr],
+    args: Expr,
+    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
+) -> Call:
+    """
+    Call a destination-passing-style packed function and return the output.
+
+    Parameters
+    ----------
+    func : Union[str, Expr]
+        The destination-passing-style packed function. A str is converted to an ExternFunc.
+
+    args : Expr
+        The input arguments.
+
+    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
+        The structure info of the call_dps_packed output.
+        It should be a single or a list of TensorStructInfo. Each one denotes the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    ret: Call
+        A call node for the call_dps_packed operator.
+    """
+    if isinstance(func, str):
+        func = ExternFunc(func)
+
+    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+        args = RxTuple((args,))
+
+    if not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    return _ffi_api.call_dps_packed(func, args, out_sinfo)  # type: ignore


 @args_converter.auto
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 48560792e1..97c8772b3b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -60,7 +60,7 @@ def LambdaLift():


 def CallTIRRewrite() -> tvm.ir.transform.Pass:
-    """Perform explicit tensor allocation for call_tir.
+    """Perform explicit tensor allocation for call_tir and call_dps_packed.

     Returns
     -------
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 466a38837f..c658b6f77d 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -48,6 +48,7 @@ from tvm.relax.op import (
     builtin,
     call_builtin_with_ctx,
     call_tir,
+    call_dps_packed,
     ceil,
     clip,
     collapse_sum_like,
@@ -545,6 +546,7 @@ __all__ = [
     "builtin",
     "call_packed",
     "call_tir",
+    "call_dps_packed",
     "call_builtin_with_ctx",
     "ceil",
     "clip",
diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc
index beb3950af1..5bd764c68e 100644
--- a/src/relax/backend/task_extraction.cc
+++ b/src/relax/backend/task_extraction.cc
@@ -73,11 +73,6 @@ class TaskExtractor : public ExprVisitor {
       return;
     }

-    // Do not extract external function
-    if (call->args[0].as<ExternFuncNode>()) {
-      return;
-    }
-
     const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
     const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 3768627c20..5eb1bf3ea6 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
 CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
   return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
 }
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> var_args) {
+  DFPattern arg_pattern;
+  if (!var_args.defined()) {
+    arg_pattern = Wildcard();
+  } else {
+    arg_pattern = var_args.value();
+  }
+
+  return IsOp("relax.call_dps_packed")(ExternFuncPattern(name), arg_pattern);
+}
+
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
+  return IsOp("relax.call_dps_packed")(ExternFuncPattern(name), var_args);
+}
 
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
   if (unordered)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index c78ca539bd..cf084d6d20 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -80,6 +80,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "sinfo_args should have exact 1 output struct info.");
   }
+  CHECK(call->args[0]->IsInstance<GlobalVarNode>())
+      << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. "
+      << "However, gets " << call->args[0];
   return call->sinfo_args[0];
 }

@@ -121,6 +124,44 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,

 TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);

+// call_dps_packed
+
+StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
+  if (call->sinfo_args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "sinfo_args should have exact 1 output struct info.");
+  }
+  return call->sinfo_args[0];
+}
+
+RELAY_REGISTER_OP("relax.call_dps_packed")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The destination-passing-style function.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);
+
+Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr)
+        << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. "
+           "However, one given structure info is "
+        << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
+  }
+
+  static const Op& op = Op::Get("relax.call_dps_packed");
+  return Call(op, {func, args}, {}, {out_sinfo});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);
+
 // call builtin
 StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
   if (call->sinfo_args.size() == 0) {
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
index 2ea039e022..6066ed8d2a 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -33,7 +33,7 @@ namespace relax {

 // ==================
 // CallTIRMutator
-// Perform explicit tensor allocation for call_tir.
+// Perform explicit tensor allocation for call_tir or call_dps_packed.
 // Example:
 // lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
 // -->
@@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
     call = expr.as<CallNode>();

     static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
     static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
     static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");

-    if (call->op == call_tir_op) {
+    if (call->op == call_tir_op || call->op == call_dps_packed_op) {
       Array<Expr> outs;
       if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
         // single output case
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 6b28f31889..622dd9ad09 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -87,13 +87,11 @@ class ConstantFolder : public ExprMutator {
    * \return The TIR function, or nullopt if pattern match fails.
    */
   Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
-    if (auto* ptr = op.as<GlobalVarNode>()) {
-      // NOTE: as check works for nullptr(returns null)
-      Optional<BaseFunc> base_func =
-          builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
-      if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
-        return GetRef<tir::PrimFunc>(pfunc);
-      }
+    const GlobalVar& global_var = Downcast<GlobalVar>(op);
+    // NOTE: as check works for nullptr(returns null)
+    Optional<BaseFunc> base_func = builder_->GetContextIRModule()->functions.Get(global_var);
+    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+      return GetRef<tir::PrimFunc>(pfunc);
     }
     return NullOpt;
   }
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 3b6b3c17ac..6d7c278d80 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -185,19 +185,17 @@ class GraphCreator : public ExprVisitor {
     // recurse into the call expression.
     const auto* op = call->op.as<OpNode>();
     if (op == call_tir_op_.get()) {
-      // Skip ExternFunc for call_dps_packed.
-      if (const auto* global_var = call->args[0].as<GlobalVarNode>()) {
-        tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(global_var)));
+      const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

-        // Override args for call_tir
-        args = Downcast<Tuple>(call->args[1])->fields;
+      // Override args for call_tir
+      args = Downcast<Tuple>(call->args[1])->fields;

-        Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
-        if (opt_pattern.defined()) {
-          pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
-        } else {
-          pattern = OpPatternKind::kOpaque;
-        }
+      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+      if (opt_pattern.defined()) {
+        pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+      } else {
+        pattern = OpPatternKind::kOpaque;
       }
     }
     // The pattern of the current binding variable node is set to the pattern of this operator.
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 925f09d85d..e90d6e4bc1 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -302,11 +302,10 @@ class FusedTIRConstructor : public ExprVisitor {

     // Step 1. Get Global var and PrimFunc
     GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
-    Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
-    ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: "
-                                 << gv;
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+
     // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication
-    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
+    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_);

     // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block
     // TODO(Siyuan): support un-schedulable functions.
@@ -364,22 +363,6 @@ class FusedTIRConstructor : public ExprVisitor {
     LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
   }

-  /********** Helper Functions **********/
-
-  /*!
-   * \brief Pattern match op to a TIR function and look it up.
-   * \return The TIR function, or NullOpt if patter match fails.
-   */
-  Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
-    // NOTE: as check works for nullptr(returns null)
-    Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
-    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
-      return GetRef<tir::PrimFunc>(pfunc);
-    } else {
-      return NullOpt;
-    }
-  }
-
   /*!
    * \brief Get the number of outputs for a call_tir node.
    * \return The number of outputs.
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index f9a84c5361..350a40c37b 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -75,6 +75,7 @@ class LegalizeMutator : public ExprMutator {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
     static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
     auto* op_node = visited_call->op.as<OpNode>();

     // Not an OpNode
@@ -102,7 +103,7 @@ class LegalizeMutator : public ExprMutator {
     }

     // No legalization.
-    if (op != call_tir_op) {
+    if (op != call_tir_op && op != call_dps_packed_op) {
       LOG(WARNING) << "No legalization func for " << op->name << " is found.";
     }
     return visited_call;
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
index aec0911ecc..e5d654fba3 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -75,11 +75,8 @@ class DataflowReshapeRewriter : public ExprMutator {
     if (call->op != call_tir_op) {
       return false;
     }
-    const auto* gv = call->args[0].as<GlobalVarNode>();
-    if (gv == nullptr) {
-      return false;
-    }
-    const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
+    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+    const auto* func = mod_->functions.Get(global_var).as<tir::PrimFuncNode>();
     ICHECK_NOTNULL(func);
     return HasReshapePattern(GetRef<tir::PrimFunc>(func));
   }
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index 7deeb139d1..b5a4d7536f 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -75,17 +75,18 @@ class CodeGenRunner : ExprMutator {
     if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
       const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);

-      auto create_call_tir = [call_node, this](Expr extern_func, StructInfo ret_struct_info) {
+      auto create_call_dps_packed = [call_node, this](Expr extern_func,
+                                                      StructInfo ret_struct_info) {
         Array<Expr> new_args({extern_func});
         new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); })));

-        static const Op& call_op = Op::Get("relax.call_tir");
+        static const Op& call_op = Op::Get("relax.call_dps_packed");

         return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
       };

       if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
-        return create_call_tir(it->second.first, it->second.second);
+        return create_call_dps_packed(it->second.first, it->second.second);
       } else {
         // TODO(@sunggg): Is there any better way to get this func?
         Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
@@ -101,7 +102,7 @@ class CodeGenRunner : ExprMutator {
           func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
           func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
           builder_->UpdateFunction(gvar, func);
-          return create_call_tir(new_func, func->ret_struct_info);
+          return create_call_dps_packed(new_func, func->ret_struct_info);
         }
       }
     }
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 2feb2082c5..e99b81df8b 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -95,9 +95,11 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi
   }
 }

-Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) {
+Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p,
+                                        const IRDocsifier& d) {
   static const Op& call_tir_op = Op::Get("relax.call_tir");
-  if (!n->op.same_as(call_tir_op)) {
+  static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+  if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) {
     return NullOpt;
   }
   ICHECK(n->args.size() == 2 || n->args.size() == 3);
@@ -123,6 +125,9 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons
   } else {
     kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
   }
+  if (n->op.same_as(call_dps_packed_op)) {
+    return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values);
+  }
   // Step 4. Print n->args[2], the tir variables
   if (n->args.size() == 3) {
     kwargs_keys.push_back("tir_vars");
@@ -134,8 +139,8 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::Call>(  //
         "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
-          // Special case: call_tir
-          if (Optional<ExprDoc> doc = PrintCallTIR(n, n_p, d)) {
+          // Special case: call_tir, call_dps_packed
+          if (Optional<ExprDoc> doc = PrintCallTIRDPSPacked(n, n_p, d)) {
             return doc.value();
           }
           ExprDoc prefix{nullptr};
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 7702f7b22d..8c4281ad78 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -82,7 +82,8 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
   }
   if (const auto* call = rhs.as<relax::CallNode>()) {
     static const Op& call_tir_op = Op::Get("relax.call_tir");
-    if (call->op.same_as(call_tir_op)) {
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+    if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
       return NullOpt;
     }
   }
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 4a345224e5..8b26a2aa64 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -66,8 +66,10 @@ def test_chained_remove_all_unused():
         def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
             with R.dataflow():
                 lv0 = x
-                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
-                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+                unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+                unused1 = R.call_dps_packed(
+                    "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+                )
                 R.output(lv0)
             return lv0

@@ -92,8 +94,10 @@ def test_binding_block_remove_all_unused():
         def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
             with R.dataflow():
                 lv0 = x
-                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
-                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+                unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+                unused1 = R.call_dps_packed(
+                    "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+                )
                 R.output(lv0)
             z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
             return z
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 7b8035b17c..49d2b76011 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -492,7 +492,7 @@ def test_sinfo_args_tir_var_used_before_define_call_tir():
     # Error: Symbolic Var m1, n1 are not defined
     m1 = tir.Var("m1", "int64")
     n1 = tir.Var("n1", "int64")
-    call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
+    call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
     func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])])
     mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
@@ -505,12 +505,12 @@ def test_sinfo_erase_to_well_formed():
     def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"):
         m = T.int64()
         n = T.int64()
-        gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+        gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
         return gv
     """
     m1 = tir.Var("m1", "int64")
     n1 = tir.Var("n1", "int64")
-    call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
+    call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
     blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]
     seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
     func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr(
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index e7f2feeaa0..de71e81464 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -429,10 +429,74 @@ def test_call_packed():

 def test_call_tir():
     # also from test_parser
+    @tvm.script.ir_module
+    class TestCallTIR:
+        @T.prim_func
+        def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
+            T.func_attr(({"global_symbol": "addone"}))
+            for i, j in T.grid(16, 16):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.int32(1)
+
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.int64(), T.int64()
+            gv0 = R.call_tir(addone, (x,), R.Tensor((m, n), dtype="float32"))
+            return gv0
+
+    mod = TestCallTIR
+    foo = mod["foo"]
+
+    foo_str = strip_whitespace(
+        dump_ast(
+            foo,
+            include_type_annotations=False,
+            include_struct_info_annotations=False,
+            include_call_attrs=False,
+        )
+    )
+    assert foo_str.startswith('Function(params=[Var(name_hint="x")]')
+
+    # call_tir is an op in Relax and it takes a GlobalVar referring to a PrimFunc as an argument
+    assert isinstance(foo.body, rx.SeqExpr)
+    tir_call = foo.body.blocks[0].bindings[0].value
+    tir_call_text = dump_ast(
+        tir_call,
+        include_type_annotations=False,
+        include_struct_info_annotations=False,
+        include_call_attrs=False,
+    )
+    assert_fields(
+        "Call",
+        {
+            "op": 'Op(name="relax.call_tir")',
+            "args": """[
+                GlobalVar(name_hint="addone"),
+                Tuple(fields=[Var(name_hint="x")])
+            ]""",
+            "sinfo_args": """[
+                TensorStructInfo(
+                    dtype=float32,
+                    shape=ShapeExpr(
+                        values=[
+                            PrimExpr(value=`m`),
+                            PrimExpr(value=`n`)
+                        ]
+                    )
+                )
+            ]""",
+        },
+        tir_call_text,
+    )
+    assert strip_whitespace(tir_call_text) in foo_str
+
+
+def test_call_dps_packed():
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")):
         m, n = T.int64(), T.int64()
-        gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
+        gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
         return gv0

     foo_str = strip_whitespace(
@@ -445,7 +509,7 @@ def test_call_tir():
     )
     assert foo_str.startswith('Function(params=[Var(name_hint="x")]')

-    # call_tir is an op in Relax and it takes an extern func as an argument
+    # call_dps_packed is an op in Relax and it takes an extern func as an argument
     assert isinstance(foo.body, rx.SeqExpr)
     tir_call = foo.body.blocks[0].bindings[0].value
     tir_call_text = dump_ast(
@@ -457,7 +521,7 @@ def test_call_tir():
     assert_fields(
         "Call",
         {
-            "op": 'Op(name="relax.call_tir")',
+            "op": 'Op(name="relax.call_dps_packed")',
             "args": """[
                 ExternFunc(global_symbol="test.op.identity"),
                 Tuple(fields=[Var(name_hint="x")])
diff --git a/tests/python/relax/test_binding_rewrite.py 
b/tests/python/relax/test_binding_rewrite.py
index 1b424b9792..d0d3344eb6 100644
--- a/tests/python/relax/test_binding_rewrite.py
+++ b/tests/python/relax/test_binding_rewrite.py
@@ -228,8 +228,10 @@ def test_chained_rm_all_unused():
         def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
             with R.dataflow():
                 lv0 = x
-                unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), 
dtype="float32"))
-                unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 
32), dtype="float32"))
+                unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 
32), dtype="float32"))
+                unused1 = R.call_dps_packed(
+                    "my_sigmoid", (unused0,), R.Tensor((32, 32), 
dtype="float32")
+                )
                 R.output(lv0)
             return lv0
 
@@ -262,19 +264,19 @@ def test_simple_replace_all_uses():
             #  \   /
             #   lv4
             with R.dataflow():
-                lv0: R.Tensor((32, 32), "float32") = R.call_tir(
+                lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                     "my_relu", (x,), R.Tensor((32, 32), dtype="float32")
                 )
-                lv1: R.Tensor((32, 32), "float32") = R.call_tir(
+                lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                     "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")
                 )
-                lv2: R.Tensor((32, 32), "float32") = R.call_tir(
+                lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                     "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32")
                 )
-                lv3: R.Tensor((32, 32), "float32") = R.call_tir(
+                lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                     "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32")
                 )
-                lv4: R.Tensor((32, 32), "float32") = R.call_tir(
+                lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed(
                     "my_whatever", (lv2, lv3), R.Tensor((32, 32), 
dtype="float32")
                 )
                 R.output(lv4)
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index ab7a5540ad..ba6ea99523 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -354,15 +354,15 @@ def test_simple_oub():
 
 def test_counter_syntax_match():
     with PatternContext() as ctx:
-        n0 = is_call_tir_extern("tir_matmul")
-        n1 = is_call_tir_extern("tir_impossible")
+        n0 = is_call_dps_packed("extern_matmul")
+        n1 = is_call_dps_packed("extern_impossible")
         n0 >> n1
         dfb = main_fn.body.blocks[0]
         assert not ctx.match_dfb(dfb)
 
     with PatternContext() as ctx:
-        n0 = is_call_tir_extern("tir_matmul")
-        n1 = is_call_tir_extern("tir_impossible")
+        n0 = is_call_dps_packed("extern_matmul")
+        n1 = is_call_dps_packed("extern_impossible")
         n0 ^ n1
         dfb = main_fn.body.blocks[0]
         assert not ctx.match_dfb(dfb)
@@ -378,20 +378,20 @@ class Diamond:
             # relu  sigmoid
             #  \      /
             #    add
-            lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32), 
dtype="float32"))
-            lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), 
dtype="float32"))
-            lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 
32), dtype="float32"))
+            lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), 
dtype="float32"))
+            lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 
32), dtype="float32"))
+            lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 
32), dtype="float32"))
             R.output(lv3)
         return lv3
 
 
 def test_diamond():
     with PatternContext() as ctx:
-        n0 = is_call_tir_extern("tir_matmul")
-        n1 = is_call_tir_extern("tir_relu")
-        n2 = is_call_tir_extern("tir_sigmoid")
-        n3 = is_call_tir_extern("tir_add")
+        n0 = is_call_dps_packed("extern_matmul")
+        n1 = is_call_dps_packed("extern_relu")
+        n2 = is_call_dps_packed("extern_sigmoid")
+        n3 = is_call_dps_packed("extern_add")
 
         n0 ^ n1
         n0 ^ n2
@@ -399,15 +399,15 @@ def test_diamond():
         n2 >> n3
 
         dfb = Diamond["main"].body.blocks[0]
         assert ctx.match_dfb(dfb)
 
     # simplify it with fork_to
     with PatternContext() as ctx:
-        n1 = is_call_tir_extern("tir_relu")
-        n2 = is_call_tir_extern("tir_sigmoid")
-        n3 = is_call_tir_extern("tir_add")
+        n1 = is_call_dps_packed("extern_relu")
+        n2 = is_call_dps_packed("extern_sigmoid")
+        n3 = is_call_dps_packed("extern_add")
 
-        is_call_tir_extern("tir_matmul").fork_to(n1, n2)
+        is_call_dps_packed("extern_matmul").fork_to(n1, n2)
         n1 >> n3
         n2 >> n3
 
@@ -417,10 +417,10 @@ def test_diamond():
 
 def test_diamond_counter_oub():
     with PatternContext() as ctx:
-        n0 = is_call_tir_extern("tir_matmul")
-        n1 = is_call_tir_extern("tir_relu")
-        n2 = is_call_tir_extern("tir_sigmoid")
-        n3 = is_call_tir_extern("tir_add")
+        n0 = is_call_dps_packed("extern_matmul")
+        n1 = is_call_dps_packed("extern_relu")
+        n2 = is_call_dps_packed("extern_sigmoid")
+        n3 = is_call_dps_packed("extern_add")
 
         n0 >> n1
         n0 >> n2
@@ -440,8 +440,8 @@ class SmallDiamond:
             #  /      \
             #  \      /
             #    add
-            lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32), 
dtype="float32"))
+            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
+            lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), 
dtype="float32"))
             R.output(lv1)
         return lv1
 
@@ -454,9 +454,9 @@ class SmallParallel:
             # relu   relu
             #   \    /
             #    add
-            lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
-            lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32), 
dtype="float32"))
+            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
+            lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), 
dtype="float32"))
+            lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), 
dtype="float32"))
             R.output(lv2)
         return lv2
 
@@ -468,8 +468,8 @@ def test_distiguish_diamond_and_parallel():
 
     with PatternContext() as ctx:
         # describe a diamond pattern
-        fork = is_call_tir_extern("my_relu")
-        join = is_call_tir_extern("my_add")
+        fork = is_call_dps_packed("my_relu")
+        join = is_call_dps_packed("my_add")
         fork.only_used_by(join, index=0)
         fork.only_used_by(join, index=1)
 
@@ -478,13 +478,13 @@ def test_distiguish_diamond_and_parallel():
 
     with PatternContext() as ctx:
         # describe a parallel pattern
-        join = is_call_tir_extern("my_add")
+        join = is_call_dps_packed("my_add")
         # Due to one-one mathcing:
-        # is_call_tir_extern("my_relu") creates the 1st relu
-        is_call_tir_extern("my_relu") >> join
-        # is_call_tir_extern("my_relu")
+        # is_call_dps_packed("my_relu") creates the 1st relu
+        is_call_dps_packed("my_relu") >> join
+        # is_call_dps_packed("my_relu")
         # creates the another different relu (obj address is different)
-        is_call_tir_extern("my_relu") >> join
+        is_call_dps_packed("my_relu") >> join
 
         assert ctx.match_dfb(parallel)
         assert not ctx.match_dfb(diamond)
@@ -507,13 +507,13 @@ class CBRx2:
         #     \   /
         #     concat
         with R.dataflow():
-            lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32), 
dtype="float32"))
-            lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32), 
dtype="float32"))
-            lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32), 
dtype="float32"))
-            lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32), 
dtype="float32"))
-            lv5 = R.call_tir("my_relu", (lv4), R.Tensor((32, 32), 
dtype="float32"))
-            lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64), 
dtype="float32"))
+            lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), 
dtype="float32"))
+            lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 
32), dtype="float32"))
+            lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), 
dtype="float32"))
+            lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), 
dtype="float32"))
+            lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 
32), dtype="float32"))
+            lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), 
dtype="float32"))
+            lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), 
dtype="float32"))
             R.output(lv6)
         return lv6
 
@@ -521,9 +521,9 @@ class CBRx2:
 def test_single_cbr():
     with PatternContext() as ctx:
         (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("bias_add")
-            >> is_call_tir_extern("my_relu")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("bias_add")
+            >> is_call_dps_packed("my_relu")
         )
         dfb = CBRx2["main"].body.blocks[0]
         matched = ctx.match_dfb(dfb)
@@ -531,9 +531,9 @@ def test_single_cbr():
 
     with PatternContext() as ctx:
         chain = (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("bias_add")
-            >> is_call_tir_extern("my_relu")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("bias_add")
+            >> is_call_dps_packed("my_relu")
         )
         dfb = CBRx2["main"].body.blocks[0]
         # we want to specifically match the first CBR (lv0)
@@ -549,9 +549,9 @@ def test_single_cbr():
 def test_counter_single_crb():
     with PatternContext() as ctx:
         (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("my_relu")
-            >> is_call_tir_extern("bias_add")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("my_relu")
+            >> is_call_dps_packed("bias_add")
         )
         dfb = CBRx2["main"].body.blocks[0]
         assert not ctx.match_dfb(dfb)
@@ -567,14 +567,14 @@ def test_nested_context():
     dfb = CBRx2["main"].body.blocks[0]
     with PatternContext() as ctx0:
         (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("bias_add")
-            >> is_call_tir_extern("my_relu")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("bias_add")
+            >> is_call_dps_packed("my_relu")
         )
         with PatternContext() as ctx1:
-            is_call_tir_extern("conv1x1") >> is_call_tir_extern("my_relu")  # 
pattern to miss
+            is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu")  # 
pattern to miss
             with PatternContext() as ctx2:
-                is_call_tir_extern("bias_add") >> is_call_tir_extern("my_relu")
+                is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
                 assert ctx2.match_dfb(dfb)
                 assert PatternContext.current() == ctx2
             assert not ctx1.match_dfb(dfb)
@@ -586,9 +586,9 @@ def test_nested_context():
 def test_two_cbr():
     with PatternContext() as ctx:
         cbr0 = (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("bias_add")
-            >> is_call_tir_extern("my_relu")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("bias_add")
+            >> is_call_dps_packed("my_relu")
         )
         cbr1 = cbr0.dup()
 
@@ -603,9 +603,9 @@ def test_two_cbr():
     with PatternContext() as ctx:
         # Deny the pattern
         cbr0 = (
-            is_call_tir_extern("conv1x1")
-            >> is_call_tir_extern("bias_add")
-            >> is_call_tir_extern("my_relu")
+            is_call_dps_packed("conv1x1")
+            >> is_call_dps_packed("bias_add")
+            >> is_call_dps_packed("my_relu")
         )
         cbr1 = cbr0.dup()
 
@@ -626,25 +626,25 @@ def test_two_matmul():
             c: R.Tensor((48, 32), "float32"),
         ) -> R.Tensor:
             with R.dataflow():
-                lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48), 
dtype="float32"))
-                lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32), 
dtype="float32"))
+                lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), 
dtype="float32"))
+                lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), 
dtype="float32"))
                 R.output(lv1)
             return lv1
 
     with PatternContext() as ctx:
-        is_call_tir_extern("matmul") >> is_call_tir_extern("matmul")
+        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
         dfb = MatMul2["main"].body.blocks[0]
         assert ctx.match_dfb(dfb)
 
     with PatternContext() as ctx:
-        is_call_tir_extern("matmul").has_shape([32, 48]) >> 
is_call_tir_extern("matmul").has_shape(
+        is_call_dps_packed("matmul").has_shape([32, 48]) >> 
is_call_dps_packed("matmul").has_shape(
             [32, 32]
         )
         dfb = MatMul2["main"].body.blocks[0]
         assert ctx.match_dfb(dfb)
 
     with PatternContext() as ctx:
-        is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") >> 
is_call_tir_extern("matmul")
+        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> 
is_call_dps_packed("matmul")
         dfb = MatMul2["main"].body.blocks[0]
         # Three MatMul cannot match
         assert not ctx.match_dfb(dfb)
@@ -661,9 +661,9 @@ def test_concat_mm_split():
             c: R.Tensor((16, 32), "float32"),
         ) -> R.Tensor:
             with R.dataflow():
-                lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32), 
dtype="float32"))
-                lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32), 
dtype="float32"))
-                lv2 = R.call_tir(
+                lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 
32), dtype="float32"))
+                lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 
32), dtype="float32"))
+                lv2 = R.call_dps_packed(
                     "my_split",
                     (lv1,),
                     [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), 
dtype="float32")],
@@ -676,15 +676,15 @@ def test_concat_mm_split():
 
     with PatternContext() as ctx:
         (
-            is_call_tir_extern("my_concat")
-            >> is_call_tir_extern("my_matmul")
-            >> is_call_tir_extern("my_split")
+            is_call_dps_packed("my_concat")
+            >> is_call_dps_packed("my_matmul")
+            >> is_call_dps_packed("my_split")
         )
         dfb = CMS["main"].body.blocks[0]
         assert ctx.match_dfb(dfb)
 
     with PatternContext() as ctx:
-        split = is_call_tir_extern("my_split")
+        split = is_call_dps_packed("my_split")
         lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32])
         lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32])
         split.fork_to(lv3, lv4)
@@ -711,18 +711,26 @@ def test_self_attention():
         ) -> R.Tensor:
             b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
             with R.dataflow():
-                fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h), 
dtype="float32"))
-                tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h, 
n), dtype="float32"))
+                fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, 
h), dtype="float32"))
+                tpq = R.call_dps_packed(
+                    "my_transpose", (fcq,), R.Tensor((b, s, h, n), 
dtype="float32")
+                )
 
-                fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h), 
dtype="float32"))
-                tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h, 
n), dtype="float32"))
+                fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, 
h), dtype="float32"))
+                tpk = R.call_dps_packed(
+                    "my_transpose", (fck,), R.Tensor((b, s, h, n), 
dtype="float32")
+                )
 
                 mul = R.multiply(tpq, tpk)
                 scale = R.multiply(mul, R.const(1.1, "float32"))
-                softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n, 
h), dtype="float32"))
+                softmax = R.call_dps_packed(
+                    "softmax", (scale,), R.Tensor((b, s, n, h), 
dtype="float32")
+                )
 
-                fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h), 
dtype="float32"))
-                tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h, 
n), dtype="float32"))
+                fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, 
h), dtype="float32"))
+                tpv = R.call_dps_packed(
+                    "my_transpose", (fcv,), R.Tensor((b, s, h, n), 
dtype="float32")
+                )
 
                 out = R.multiply(softmax, tpv)
                 R.output(out)
@@ -730,7 +738,7 @@ def test_self_attention():
             return out
 
     with PatternContext() as ctx:
-        fc_trans_q = is_call_tir_extern("my_fc") >> 
is_call_tir_extern("my_transpose")
+        fc_trans_q = is_call_dps_packed("my_fc") >> 
is_call_dps_packed("my_transpose")
         fc_trans_k = fc_trans_q.dup()
         fc_trans_v = fc_trans_q.dup()
 
@@ -752,43 +760,59 @@ def test_nested_diamond():
                 #      add5      add6
                 #          \    /
                 #           add7
-                lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), 
dtype="float32"))
-                lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), 
dtype="float32"))
-                lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), 
dtype="float32"))
-                lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32), 
dtype="float32"))
-                lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32), 
dtype="float32"))
-                lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32), 
dtype="float32"))
-                lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32), 
dtype="float32"))
-                lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32), 
dtype="float32"))
+                lv0 = R.call_dps_packed(
+                    "extern_matmul", (x, w), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv1 = R.call_dps_packed(
+                    "extern_matmul", (x, w), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv2 = R.call_dps_packed(
+                    "extern_sigmoid", (lv0), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv3 = R.call_dps_packed(
+                    "extern_sigmoid", (lv1), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv4 = R.call_dps_packed(
+                    "extern_add", (lv0, lv1), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv5 = R.call_dps_packed(
+                    "extern_add", (lv2, lv4), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv6 = R.call_dps_packed(
+                    "extern_add", (lv3, lv4), R.Tensor((32, 32), 
dtype="float32")
+                )
+                lv7 = R.call_dps_packed(
+                    "extern_add", (lv5, lv6), R.Tensor((32, 32), 
dtype="float32")
+                )
                 R.output(lv7)
             return lv7
 
     # match matmul0 diamond
     with PatternContext() as ctx:
-        sigmoid2 = is_call_tir_extern("tir_sigmoid")
-        add4 = is_call_tir_extern("tir_add")
-        is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4)
-        add5 = is_call_tir_extern("tir_add")
+        sigmoid2 = is_call_dps_packed("extern_sigmoid")
+        add4 = is_call_dps_packed("extern_add")
+        is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
+        add5 = is_call_dps_packed("extern_add")
         sigmoid2 >> add5
         add4 ^ add5
         assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
 
     # counter case: mis-match matmul0 diamond
     with PatternContext() as ctx:
-        sigmoid2 = is_call_tir_extern("tir_sigmoid")
-        add4 = is_call_tir_extern("tir_add")
-        is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4)
-        add5 = is_call_tir_extern("tir_add")
+        sigmoid2 = is_call_dps_packed("extern_sigmoid")
+        add4 = is_call_dps_packed("extern_add")
+        is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
+        add5 = is_call_dps_packed("extern_add")
         sigmoid2 >> add5
         add4 >> add5  # not only-used-by relation
         assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
 
     # match matmul1 diamond
     with PatternContext() as ctx:
-        sigmoid3 = is_call_tir_extern("tir_sigmoid")
-        add4 = is_call_tir_extern("tir_add")
-        is_call_tir_extern("tir_matmul").fork_to(sigmoid3, add4)
-        add6 = is_call_tir_extern("tir_add")
+        sigmoid3 = is_call_dps_packed("extern_sigmoid")
+        add4 = is_call_dps_packed("extern_add")
+        is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4)
+        add6 = is_call_dps_packed("extern_add")
         sigmoid3 >> add6
         add4 ^ add6
         assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
@@ -796,11 +820,11 @@ def test_nested_diamond():
     # match add-4-5-6-7
     with PatternContext() as ctx:
         add5, add6, add7 = (
-            is_call_tir_extern("tir_add"),
-            is_call_tir_extern("tir_add"),
-            is_call_tir_extern("tir_add"),
+            is_call_dps_packed("extern_add"),
+            is_call_dps_packed("extern_add"),
+            is_call_dps_packed("extern_add"),
         )
-        is_call_tir_extern("tir_add").fork_to(add5, add6)  # add4
+        is_call_dps_packed("extern_add").fork_to(add5, add6)  # add4
         add5 >> add7
         add6 >> add7
         assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])
@@ -811,15 +835,15 @@ def test_incremental_solving():
     def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             # relu -> sigmoid -> neg
-            lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), 
dtype="float32"))
-            lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), 
dtype="float32"))
+            lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 
32), dtype="float32"))
+            lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), 
dtype="float32"))
             R.output(lv2)
         return lv2
 
-    relu = is_call_tir_extern("tir_relu")
-    sigmoid = is_call_tir_extern("tir_sigmoid")
-    neg = is_call_tir_extern("tir_neg")
+    relu = is_call_dps_packed("extern_relu")
+    sigmoid = is_call_dps_packed("extern_sigmoid")
+    neg = is_call_dps_packed("extern_neg")
 
     with PatternContext() as ctx0:
         relu >> sigmoid
@@ -840,14 +864,14 @@ def test_incremental_solving_counter():
     def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             # sigmoid -> neg
-            lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32), 
dtype="float32"))
-            lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), 
dtype="float32"))
+            lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), 
dtype="float32"))
             R.output(lv1)
         return lv1
 
-    relu = is_call_tir_extern("tir_relu")
-    sigmoid = is_call_tir_extern("tir_sigmoid")
-    neg = is_call_tir_extern("tir_neg")
+    relu = is_call_dps_packed("extern_relu")
+    sigmoid = is_call_dps_packed("extern_sigmoid")
+    neg = is_call_dps_packed("extern_neg")
 
     with PatternContext() as ctx0:
         relu >> sigmoid  # cannot match
diff --git a/tests/python/relax/test_op_misc.py 
b/tests/python/relax/test_op_misc.py
index 523a628fa9..fd23911533 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -39,7 +39,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None:
 
 def test_call_tir() -> None:
     v0 = rx.Var("v0", R.Tensor([54, 96], "float32"))
-    v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), 
"float32"))
+    v1 = rx.call_dps_packed(rx.extern("test.op.identity"), [v0], R.Tensor((54, 
96), "float32"))
     v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32"))
 
 
diff --git a/tests/python/relax/test_transform.py 
b/tests/python/relax/test_transform.py
index 85de4f912e..3e6305c492 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -32,8 +32,10 @@ def test_to_non_dataflow():
         def foo(x: R.Tensor(("m", "n"), "float32")):
             m, n = T.int64(), T.int64()
             with R.dataflow():
-                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
-                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32"))
+                lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, 
n), dtype="float32"))
+                gv0 = R.call_dps_packed(
+                    "test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32")
+                )
                 R.output(gv0)
             return gv0
 
@@ -73,10 +75,14 @@ def test_to_non_dataflow():
 def test_call_tir_rewrite():
     @tvm.script.ir_module
     class TestCallTIRRewrite:
+        @T.prim_func
+        def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), 
"float32")):
+            T.evaluate(0)
+
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")):
             m, n = T.int64(), T.int64()
-            gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
+            gv0 = R.call_tir(exp, (x,), R.Tensor((m, n), dtype="float32"))
             return gv0
 
     mod = TestCallTIRRewrite
@@ -94,6 +100,40 @@ def test_call_tir_rewrite():
     block = func.body.blocks[0]
     assert not isinstance(block, relax.DataflowBlock)
 
+    s1 = block.bindings[0].value
+    assert isinstance(s1, relax.Call)
+    assert s1.op.name == "relax.builtin.alloc_tensor"
+    assert isinstance(s1.args[0], relax.ShapeExpr)
+    assert structural_equal(s1.args[0], s0.sinfo_args[0].shape)
+    s2 = block.bindings[1].value
+    assert isinstance(s2.op, tvm.ir.expr.GlobalVar)  # callee must be a GlobalVar
+    assert s2.op.name_hint == "exp"
+
+
+def test_call_dps_packed_rewrite():
+    @tvm.script.ir_module
+    class TestCallDPSPackedRewrite:
+        @R.function
+        def foo(x: R.Tensor(("m", "n"), "float32")):
+            m, n = T.int64(), T.int64()
+            gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
+            return gv0
+
+    mod = TestCallDPSPackedRewrite
+
+    # before rewrite
+    v0 = mod["foo"].body.blocks[0].bindings[0].var
+    s0 = mod["foo"].body.blocks[0].bindings[0].value
+    assert isinstance(s0, relax.Call)
+    assert s0.op.name == "relax.call_dps_packed"
+
+    # CallTIRRewrite also works for call_dps_packed
+    new_mod = relax.transform.CallTIRRewrite()(mod)
+    func = new_mod["foo"]
+
+    block = func.body.blocks[0]
+    assert not isinstance(block, relax.DataflowBlock)
+
     s1 = block.bindings[0].value
     assert isinstance(s1, relax.Call)
     assert s1.op.name == "relax.builtin.alloc_tensor"
diff --git a/tests/python/relax/test_transform_attach_global_symbol.py 
b/tests/python/relax/test_transform_attach_global_symbol.py
index cef3842e3e..7fc6798e37 100644
--- a/tests/python/relax/test_transform_attach_global_symbol.py
+++ b/tests/python/relax/test_transform_attach_global_symbol.py
@@ -45,7 +45,7 @@ class Before:
     @R.function
     def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), 
"float32")) -> R.Tensor:
         m, n, k = T.int64(), T.int64(), T.int64()
-        gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
+        gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32"))
         return gv0
 
 
@@ -75,7 +75,7 @@ def test_basic():
         ) -> R.Tensor:
             R.func_attr({"global_symbol": "main"})
             m, n, k = T.int64(), T.int64(), T.int64()
-            gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), 
dtype="float32"))
+            gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), 
dtype="float32"))
             return gv0
 
     before = Before
diff --git a/tests/python/relax/test_transform_bind_params.py 
b/tests/python/relax/test_transform_bind_params.py
index 1dfd9e0c8e..2a30586b1b 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -87,10 +87,10 @@ def test_bind_params_symbolic_vars():
             m = T.Var("m", "int64")
             n = T.Var("n", "int64")
             with R.dataflow():
-                lv0 = R.call_tir(
+                lv0 = R.call_dps_packed(
                     "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), 
dtype="float32")
                 )
-                out = R.call_tir(
+                out = R.call_dps_packed(
                     "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), 
dtype="float32")
                 )
                 R.output(out)
diff --git a/tests/python/relax/test_transform_codegen_pass.py 
b/tests/python/relax/test_transform_codegen_pass.py
index 3e9501147a..d82706200a 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -235,12 +235,12 @@ class Conv2dx2_after:
         weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
     ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
         with R.dataflow():
-            lv = R.call_tir(
+            lv = R.call_dps_packed(
                 "fused_relax_nn_conv2d_tensorrt",
                 (data, weight1),
                 out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
             )
-            gv = R.call_tir(
+            gv = R.call_dps_packed(
                 "fused_relax_nn_conv2d_tensorrt",
                 (lv, weight2),
                 out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 33d57417cf..14d70ab77c 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -826,7 +826,7 @@ def test_skip_call_dps_packed():
         @R.function
         def main(x: R.Tensor((2, 3), "float32")):
             with R.dataflow():
-                y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
+                y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
                 R.output(y)
             return y
 
@@ -842,7 +842,7 @@ def test_edge_with_call_dps_packed():
             with R.dataflow():
                 a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3), 
"float32"))
                 b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3), 
"float32"))
-                c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), 
"float32"))
+                c = R.call_dps_packed("packed_dps", (a,), 
out_sinfo=R.Tensor((2, 3), "float32"))
                 R.output(b, c)
             return R.tuple(b, c)
 
diff --git a/tests/python/relax/test_transform_fuse_tir.py 
b/tests/python/relax/test_transform_fuse_tir.py
index c2784edec7..f8d488e43b 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -690,7 +690,7 @@ def test_skip_call_dps_packed():
         @R.function
         def main(x: R.Tensor((2, 3), "float32")):
             with R.dataflow():
-                y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
+                y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
                 R.output(y)
             return y
 
diff --git a/tests/python/relax/test_transform_normalize.py 
b/tests/python/relax/test_transform_normalize.py
index da123f956d..874e83c7f9 100644
--- a/tests/python/relax/test_transform_normalize.py
+++ b/tests/python/relax/test_transform_normalize.py
@@ -124,8 +124,10 @@ def test_normalize_no_op():
         def foo(x: R.Tensor(("m", "n"), "float32")):
             m, n = T.int64(), T.int64()
             with R.dataflow():
-                lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32"))
-                gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32"))
+                lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, 
n), dtype="float32"))
+                gv0 = R.call_dps_packed(
+                    "test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32")
+                )
                 R.output(gv0)
             return gv0
 
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py 
b/tests/python/relax/test_tvmscript_ir_builder.py
index f7c29b8dbe..e103e9cddd 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -25,7 +25,7 @@ def test_function_simple():
     """
     @R.function
     def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", 
ndim=2):
-        out = R.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+        out = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
         return out
     """
     # create with Script IRBuilder
@@ -35,8 +35,15 @@ def test_function_simple():
             R.func_attr({"Primitive": 1})
             x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
             R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", 
ndim=2))
+            y = R.emit(
+                R.call_dps_packed(
+                    "extern_func", x, relax.TensorStructInfo((128, 128), 
dtype="float32")
+                )
+            )
             out = R.emit(
-                R.call_tir("extern_func", x, relax.TensorStructInfo((128, 
128), dtype="float32"))
+                R.call_dps_packed(
+                    "extern_dps_func", y, relax.TensorStructInfo((128, 128), 
dtype="float32")
+                )
             )
             IRBuilder.name("out", out)
             R.func_ret_value(out)
@@ -45,8 +52,15 @@ def test_function_simple():
     x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,), attrs={"Primitive": 1}):
+        y = bb.emit(
+            relax.call_dps_packed(
+                "extern_func", x, relax.TensorStructInfo((128, 128), 
dtype="float32")
+            )
+        )
         out = bb.emit(
-            relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 
128), dtype="float32"))
+            relax.call_dps_packed(
+                "extern_dps_func", y, relax.TensorStructInfo((128, 128), 
dtype="float32")
+            )
         )
         bb.emit_func_output(out)
     mod = bb.get()
@@ -112,7 +126,7 @@ def test_dataflow_block():
     def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim 
= 2):
         # block 0
         with R.dataflow():
-            lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_func", (x,), R.Tensor((128, 128), 
dtype="float32"))
             gv: Tensor((128, 128), "float32") = lv0
             R.output(gv)
         return gv
@@ -124,7 +138,7 @@ def test_dataflow_block():
             x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
             with R.dataflow() as df:
                 lv0 = R.emit(
-                    R.call_tir(
+                    R.call_dps_packed(
                         "extern_func", x, relax.TensorStructInfo((128, 128), 
dtype="float32")
                     )
                 )
@@ -142,7 +156,7 @@ def test_dataflow_block():
     with bb.function("foo", (x,)):
         with bb.dataflow():
             lv0 = bb.emit(
-                relax.call_tir(
+                relax.call_dps_packed(
                     "extern_func", x, relax.TensorStructInfo((128, 128), 
dtype="float32")
                 )
             )
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index ffc024403d..136bd8c1ea 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -43,13 +43,17 @@ def test_simple_func():
     @R.function
     def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), 
"float32"):
         R.func_attr({"Primitive": 1})
-        gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
-        return gv0
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+        gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), 
dtype="float32"))
+        return gv1
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,), attrs={"Primitive": 1}):
-        out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32")))
+        y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 
128), dtype="float32")))
+        out = bb.emit(
+            relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), 
dtype="float32"))
+        )
         bb.emit_func_output(out)
 
     _check(foo, bb.get()["foo"])
@@ -111,15 +115,34 @@ def test_unexpected_tir_cast_args():
             return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), 
dtype="float32"))
 
 
-def test_unexpected_tir_max_args():
+def test_unexpected_tir_args():
+
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @tvm.script.ir_module
+        class TestWellCallTIR:
+            @T.prim_func
+            def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 
16), "int32")) -> None:
+                T.func_attr(({"global_symbol": "tir_addone"}))
+                for i, j in T.grid(16, 16):
+                    with T.block("tir_addone"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] + T.int32(1)
+
+            @R.function
+            def foo(x: R.Tensor(("m", "m"), "float32")):
+                m = T.int64()
+                # tir.max expects 2 arguments, but got 1
+                gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), 
dtype="float32"))
+                return gv
 
     with pytest.raises(tvm.error.DiagnosticError):
 
         @R.function
         def f(x: R.Tensor(("m", "n"), "float32")):
             m = T.int64()
-            # tir.max expects 2 arguments, but got 1
-            return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), 
dtype="float32"))
+            # call_tir expected a tir prim_func
+            return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), 
dtype="float32"))
 
 
 def test_func_type_annotation_fail():
@@ -315,14 +338,14 @@ def test_symbolic_shape():
     def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
         m = T.int64()
         n = T.int64()
-        gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
         return gv0
 
     @R.function
     def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
         m = T.int64()
         n = T.int64()
-        gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
         return gv0
 
     with pytest.raises(tvm.error.DiagnosticError):
@@ -331,7 +354,7 @@ def test_symbolic_shape():
         def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> 
R.Tensor(None, "float32", ndim=2):
             m = T.int64()
             n = T.int32()  # The shape dtype should be int64
-            gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
+            gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
             return gv0
 
     def _expected(name: str):
@@ -339,7 +362,9 @@ def test_symbolic_shape():
         x = relax.Var("x", R.Tensor([m, n], "float32"))
         bb = relax.BlockBuilder()
         with bb.function(name, (x,)):
-            out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), 
dtype="float32")))
+            out = bb.emit(
+                relax.call_dps_packed("extern_func", x, R.Tensor((m, n), 
dtype="float32"))
+            )
             bb.emit_func_output(out)
         return bb.get()[name]
 
@@ -403,15 +428,15 @@ def test_match_cast():
 def test_tuple_return():
     @R.function
     def foo(x: R.Tensor((4, 4), "float32")):
-        gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))
-        gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), 
dtype="float32"))
+        gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), 
dtype="float32"))
         return (gv0, gv1)
 
     x = relax.Var("x", R.Tensor((4, 4), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,)):
-        gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), 
dtype="float32")))
-        gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), 
dtype="float32")))
+        gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 
4), dtype="float32")))
+        gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4, 
4), dtype="float32")))
         bb.emit_func_output(relax.Tuple((gv0, gv1)))
 
     _check(foo, bb.get()["foo"])
@@ -483,8 +508,8 @@ def test_dataflow_block():
     @R.function
     def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", 
ndim=2):
         with R.dataflow():
-            lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
-            lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+            lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
             gv = lv1
             R.output(gv)
         return gv
@@ -493,8 +518,12 @@ def test_dataflow_block():
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,)):
         with bb.dataflow():
-            lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 
128), dtype="float32")))
-            lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 
128), dtype="float32")))
+            lv0 = bb.emit(
+                relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+            )
+            lv1 = bb.emit(
+                relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
+            )
             gv = bb.emit_output(lv1)
         bb.emit_func_output(gv)
 
@@ -504,22 +533,22 @@ def test_dataflow_block():
 def test_dataflow_block_advanced():
     @R.function
     def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", 
ndim=2):
-        gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
-        gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), 
dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+        gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), 
dtype="float32"))
         with R.dataflow():
             m = T.int64()
             n = T.int64()
-            lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), 
dtype="float32"))
+            lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), 
dtype="float32"))
             lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
-            gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
-            gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), 
dtype="float32"))
+            gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
+            gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), 
dtype="float32"))
             gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32"))
             gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
             gv4 = gv3
             gv5 = gv2
             R.output(gv5, gv4)
-        gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), 
dtype="float32"))
-        gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), 
dtype="float32"))
+        gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), 
dtype="float32"))
+        gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), 
dtype="float32"))
         return gv7
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
@@ -527,21 +556,33 @@ def test_dataflow_block_advanced():
     m = tir.Var("m", dtype="int64")
     n = tir.Var("n", dtype="int64")
     with bb.function("foo", (x,)):
-        gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), 
dtype="float32")))
-        gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), 
dtype="float32")))
+        gv0 = bb.emit(
+            relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), 
dtype="float32"))
+        )
+        gv1 = bb.emit(
+            relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), 
dtype="float32"))
+        )
         with bb.dataflow():
-            lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 
128), dtype="float32")))
+            lv0 = bb.emit(
+                relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), 
dtype="float32"))
+            )
             lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
-            gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 
128), dtype="float32")))
+            gv2 = bb.emit(
+                relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), 
dtype="float32"))
+            )
             gv21 = bb.emit(
-                relax.call_tir("extern_func", gv2, R.Tensor((128, 128), 
dtype="float32"))
+                relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), 
dtype="float32"))
             )
             gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32"))
             gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
             gv32 = bb.emit_output(gv31)
             gv22 = bb.emit_output(gv21)
-        gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), 
dtype="float32")))
-        gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), 
dtype="float32")))
+        gv4 = bb.emit(
+            relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), 
dtype="float32"))
+        )
+        gv5 = bb.emit(
+            relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), 
dtype="float32"))
+        )
         bb.emit_func_output(gv5)
 
     _check(foo, bb.get()["foo"])
@@ -640,13 +681,13 @@ def test_function_without_return():
 def test_tensor_type_without_args():
     @R.function
     def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
-        v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))
+        v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), 
dtype="float32"))
         return v
 
     x = relax.Var("x", R.Tensor((32, 32), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x)):
-        v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), 
dtype="float32")))
+        v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), 
dtype="float32")))
         bb.emit_func_output(v)
 
     _check(foo, bb.get()["foo"])
@@ -753,10 +794,10 @@ def test_annotate_override():
     assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo)
 
 
-def test_call_tir_empty_shape():
+def test_call_dps_packed_empty_shape():
     @R.function
     def foo(x: R.Tensor((), "float32")):
-        z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32"))
+        z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32"))
         return z
 
     (z_bind,) = foo.body.blocks[0].bindings
@@ -1024,7 +1065,7 @@ def test_symbolic_shape_computing():
         x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), 
"float32")
     ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"):
         m = T.int64()
-        z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), 
dtype="float32"))
+        z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 
1,), dtype="float32"))
         return z
 
     m = tir.Var("m", "int64")
@@ -1033,7 +1074,9 @@ def test_symbolic_shape_computing():
     bb = relax.BlockBuilder()
     with bb.function("bar", (x, y)):
         z = bb.emit(
-            relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 
1,), dtype="float32"))
+            relax.call_dps_packed(
+                "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), 
dtype="float32")
+            )
         )
         bb.emit_func_output(z)
 
@@ -1043,7 +1086,7 @@ def test_symbolic_shape_computing():
     @R.function
     def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
         m = T.int64()
-        z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
+        z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), 
dtype="float32"))
         return z
 
     m = tir.Var("m", "int64")
@@ -1051,7 +1094,7 @@ def test_symbolic_shape_computing():
     y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
     bb = relax.BlockBuilder()
     with bb.function("baz", (x, y)):
-        z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), 
dtype="float32")))
+        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 
2,), dtype="float32")))
         bb.emit_func_output(z)
 
     _check(baz, bb.get()["baz"])
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py 
b/tests/python/relax/test_tvmscript_printer_relax.py
index 464591f259..76bb3bb812 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -299,13 +299,22 @@ def test_shape_expr():
 def test_call():
     x = tir.Var("x", "int64")
     a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
-    obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, 
tir_vars=[x])
+    o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, 
out_sinfo=a.struct_info, tir_vars=[x])
+    o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info)
     _assert_print(
-        obj,
+        o0,
+        """
+x = T.int64()
+a: R.Tensor((1, x, 3), dtype="float32")
+R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), 
tir_vars=R.shape([x]))
+""",
+    )
+    _assert_print(
+        o1,
         """
 x = T.int64()
 a: R.Tensor((1, x, 3), dtype="float32")
-R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), 
tir_vars=R.shape([x]))
+R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), 
dtype="float32"))
 """,
     )
 
diff --git a/tests/python/relax/test_vm_build.py 
b/tests/python/relax/test_vm_build.py
index e51e22e323..776679103f 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -121,7 +121,7 @@ def test_vm_compile_stage3(exec_mode):
         @R.function
         def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
             with R.dataflow():
-                y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), 
dtype="float32"))
+                y = R.call_dps_packed("test.vm.identity", (x), R.Tensor((32, 
16), dtype="float32"))
                 R.output(y)
             return y
 
@@ -145,7 +145,7 @@ def test_vm_compile_e2e(exec_mode):
             with R.dataflow():
                 n, m = T.int64(), T.int64()
                 _ = R.match_cast(x, R.Tensor((n, m), "float32"))
-                y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), 
dtype="float32"))
+                y = R.call_dps_packed("test.vm.tile", (x), R.Tensor((n, m * 
2), dtype="float32"))
                 R.output(y)
             return y
 
@@ -714,7 +714,7 @@ class TestVMSetInput:
 
     @R.function
     def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), 
"float32")) -> R.Tensor:
-        gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), 
dtype="float32"))
+        gv0 = R.call_tir(test_vm_mul, (x, w), R.Tensor((32, 32), 
dtype="float32"))
         return gv0
 
 

Reply via email to