This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 521ab47edf [MSC] Reconstruct tensorrt module (#17344)
521ab47edf is described below

commit 521ab47edf1a2b25b6614d64df5d9f6133dfa329
Author: Archermmt <archer...@126.com>
AuthorDate: Sun Sep 8 18:40:49 2024 +0800

    [MSC] Reconstruct tensorrt module (#17344)
    
    * reconstruct tensorrt
    
    * format fix
---
 python/tvm/contrib/msc/core/frontend/translate.py  |   2 +-
 .../msc/framework/tensorrt/frontend/translate.py   |   5 +-
 .../msc/framework/tensorrt/transform/pattern.py    |  31 +-
 .../msc/framework/tensorrt/transform/transform.py  |  13 +-
 src/contrib/msc/core/transform/rewrite_utils.cc    |  58 ++
 src/contrib/msc/core/transform/rewrite_utils.h     |  72 +++
 src/contrib/msc/core/utils.cc                      |  19 +-
 src/contrib/msc/core/utils.h                       |   4 +-
 .../msc/framework/tensorrt/tensorrt_opcode.cc      |   6 +-
 .../msc/framework/tensorrt/transform_tensorrt.cc   | 668 +++++++++++++--------
 .../contrib/test_msc/test_translate_tensorrt.py    |  47 +-
 11 files changed, 642 insertions(+), 283 deletions(-)

diff --git a/python/tvm/contrib/msc/core/frontend/translate.py 
b/python/tvm/contrib/msc/core/frontend/translate.py
index 63b4424524..cea021ade3 100644
--- a/python/tvm/contrib/msc/core/frontend/translate.py
+++ b/python/tvm/contrib/msc/core/frontend/translate.py
@@ -330,7 +330,7 @@ def byoc_partition(
     msc_mod = _partition_mod(mod)
     func_names = [var.name_hint for var, func in msc_mod.functions.items() if 
_is_target_func(func)]
 
-    if not trans_config.get("allow_incomplete", False):
+    if trans_config.get("as_complete", True):
         assert len(func_names) == 1, "More than 1 target func is found: " + 
str(msc_mod)
         BYOCChecker().check(func_names, msc_mod[entry])
 
diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py 
b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
index 8758fdb630..4a02b02728 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py
@@ -49,7 +49,10 @@ def transform_for_tensorrt(
     return tvm.transform.Sequential(
         [
             msc_transform.SetExprName(),
-            trt_transform.TransformTensorRT(trans_config.get("version")),
+            trt_transform.TransformTensorRT(
+                version=trans_config.get("version"),
+                linear_to_conv=trans_config.get("linear_to_conv", False),
+            ),
             relax.transform.FoldConstant(),
         ]
     )(mod)
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py 
b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
index 8eea3f7081..17aee690e3 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -136,12 +136,22 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] = 
None) -> bool:
         return True
     if isinstance(expr, relax.Tuple):
         return all(_check_expr(field) for field in expr.fields)
-    if any(i < 0 for i in expr.struct_info.shape.values):
-        return False
-    dtypes = dtypes or ("float32", "float16")
-    if expr.struct_info.dtype not in dtypes:
-        return False
-    return True
+    dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool")
+
+    def _check(sinfo):
+        if not sinfo.shape or sinfo.dtype not in dtypes:
+            return False
+        unknown_dim = 0
+        for s in sinfo.shape.values:
+            if isinstance(s, (tvm.tir.Var, tvm.tir.Any)):
+                unknown_dim += 1
+            elif isinstance(s, tvm.tir.IntImm) and s < 0:
+                unknown_dim += 1
+        return unknown_dim <= 1
+
+    if isinstance(expr.struct_info, relax.TupleStructInfo):
+        return all(_check(s) for s in expr.struct_info.fields)
+    return _check(expr.struct_info)
 
 
 def _basic_check(context: PatternCheckContext) -> bool:
@@ -216,8 +226,7 @@ def _reshape_check(context: PatternCheckContext) -> bool:
         Whether the pattern is correct.
     """
 
-    dtypes = ("float32", "float16", "int32")
-    if any(not _check_expr(context.annotated_expr[key], dtypes) for key in 
["input_0", "out"]):
+    if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", 
"out"]):
         return False
     return True
 
@@ -323,16 +332,18 @@ def get_patterns(target) -> List[Pattern]:
         "nn.avg_pool2d": ["input"],
         "nn.conv2d": ["input", "constant"],
         "nn.max_pool2d": ["input"],
+        "astype": ["input"],
         "concat": ["input"],
         "clip": ["input", "input", "input"],
         "image.resize2d": ["input", "input"],
         "matmul": ["input", "input"],
         "permute_dims": ["input"],
-        "strided_slice": ["input"],
+        "strided_slice": ["input", "input", "input", "input", "input"],
+        "topk": ["input"],
     }
     activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"]
     reduce_ops = ["max", "min", "mean", "sum"]
-    unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt", 
"tan"]
+    unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square", 
"sqrt", "tan"]
     elemwise_ops = [
         "add",
         "divide",
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py 
b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
index d6f15c43da..cf4d4b9f33 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
@@ -25,18 +25,25 @@ from tvm.contrib.msc.core.utils import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils
 
 
-def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass:
+def TransformTensorRT(
+    version: List[int] = None, linear_to_conv: bool = False
+) -> tvm.ir.transform.Pass:
     """Transform the Function to fit TensorRT.
 
     Parameters
     ----------
     version: list<int>
         The tensorrt version.
+    linear_to_conv: bool
+        Whether to cast linear to conv2d
 
     Returns
     -------
     ret: tvm.ir.transform.Pass
     """
 
-    version = version or msc_utils.get_version(MSCFramework.TENSORRT)
-    return relax_api.TransformTensorRT(version)  # type: ignore
+    config = {
+        "version": version or msc_utils.get_version(MSCFramework.TENSORRT),
+        "linear_to_conv": linear_to_conv,
+    }
+    return relax_api.TransformTensorRT(msc_utils.dump_dict(config))  # type: 
ignore
diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc 
b/src/contrib/msc/core/transform/rewrite_utils.cc
new file mode 100644
index 0000000000..20e4821e6f
--- /dev/null
+++ b/src/contrib/msc/core/transform/rewrite_utils.cc
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/rewrite_utils.cc
+ */
+#include "rewrite_utils.h"
+
+#include <set>
+#include <string>
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& 
expr) {
+  expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name);
+  return builder->Emit(expr, name);
+}
+
+Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, 
Array<Expr> args,
+                           Attrs attrs) {
+  const auto& call = Call(op, args, attrs);
+  return ReEmit(builder, name, call);
+}
+
+Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, 
double value,
+                                const DataType& dtype, size_t ndim) {
+  const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
+  Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name);
+  const auto& constant = Constant(data, NullOpt, span);
+  if (ndim == 0) {
+    return constant;
+  }
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  Array<PrimExpr> exp_shape(ndim, Integer(1));
+  return MakeCall(builder, name + "_exp", reshape_op, {constant, 
ShapeExpr(exp_shape)});
+}
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/contrib/msc/core/transform/rewrite_utils.h 
b/src/contrib/msc/core/transform/rewrite_utils.h
new file mode 100644
index 0000000000..2693a6ccd2
--- /dev/null
+++ b/src/contrib/msc/core/transform/rewrite_utils.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/msc/core/transform/rewrite_utils.h
+ * \brief Common utilities for rewrite.
+ */
+#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
+#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
+
+#include <tvm/ir/source_map.h>
+#include <tvm/relax/expr.h>
+
+#include <vector>
+
+#include "../../../../relax/transform/utils.h"
+#include "../../../../support/scalars.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace contrib {
+namespace msc {
+
+using Expr = tvm::RelayExpr;
+using namespace tvm::relax;
+
+/*!
+ * \brief Utils for Layout.
+ */
+class RewriteUtils {
+ public:
+  /*!
+   * \brief Emit call with span name.
+   * \return The emitted var.
+   */
+  TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const 
Expr& expr);
+
+  /*!
+   * \brief Make and emit a call binding with span.
+   * \return The emitted var.
+   */
+  TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr 
op, Array<Expr> args,
+                              Attrs attrs = Attrs());
+
+  /*!
+   * \brief Make and emit a (shaped)constant with span.
+   * \return The constant/reshape.
+   */
+  TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, 
double value,
+                                   const DataType& dtype, size_t ndim = 0);
+};
+
+}  // namespace msc
+}  // namespace contrib
+}  // namespace tvm
+#endif  // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index c6e74d4284..1e846b0b3a 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr, 
const String& suffix) {
   return name;
 }
 
-const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
-  const auto& shape_opt = 
Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
-  ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
+const Array<PrimExpr> ExprUtils::GetShape(const relax::TensorStructInfo& 
sinfo, bool as_int) {
+  const auto& shape_opt = sinfo->GetShape();
+  if (!shape_opt.defined()) {
+    return Array<PrimExpr>();
+  }
+  if (as_int) {
+    Array<PrimExpr> shape;
+    for (const auto& s : shape_opt.value()) {
+      shape.push_back(s->IsInstance<IntImmNode>() ? s : Integer(-1));
+    }
+    return shape;
+  }
   return shape_opt.value();
 }
 
+const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr, bool as_int) {
+  return 
GetShape(Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr)), as_int);
+}
+
 const DataType ExprUtils::GetDataType(const Expr& expr) {
   return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
 }
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index d7758cc23d..7fb9c87a99 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -398,7 +398,9 @@ class ExprUtils {
    * \brief Get shape of expr.
    * \return The shape.
    */
-  TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
+  TVM_DLL static const Array<PrimExpr> GetShape(const relax::TensorStructInfo& 
sinfo,
+                                                bool as_int = true);
+  TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr, bool as_int 
= true);
 
   /*!
    * \brief Get dtype of expr.
diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc 
b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
index a080fdd778..d90cdc35d1 100644
--- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
+++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
@@ -92,6 +92,8 @@ const String TensorRTOpCode::DType(const DataType& dtype) {
     dtype_enum = "DataType::kINT8";
   } else if (dtype_name == "int32") {
     dtype_enum = "DataType::kINT32";
+  } else if (dtype_name == "int64") {
+    dtype_enum = "DataType::kINT32";
   } else if (dtype_name == "float16") {
     dtype_enum = "DataType::kHALF";
   } else if (dtype_name == "float32") {
@@ -267,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode {
   void CodeGenBuild() final {
     stack_.op_call()
         .op_input_arg()
-        .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode()))
+        .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode()))
         .call_arg(0)
         .op_dtype_arg(node()->OutputAt(0)->dtype);
   }
@@ -661,7 +663,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode {
 
  protected:
   void CodeGenBuild() final {
-    const String& symbol = node()->GetTypeAttr<bool>("is_asend") ? "MIN" : 
"MAX";
+    const String& symbol = node()->GetTypeAttr<bool>("largest") ? "MAX" : 
"MIN";
     stack_.op_call()
         .op_input_arg()
         .call_arg("TopKOperation::k" + symbol)
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc 
b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index 3f85309cd8..542e15d06c 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -22,83 +22,101 @@
  * \brief Pass for transform the function to tensorrt.
  */
 
+#include <tvm/relax/attrs/sorting.h>
 #include <tvm/relax/expr.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
 
 #include "../../../../relax/transform/utils.h"
 #include "../../../../support/scalars.h"
+#include "../../core/transform/rewrite_utils.h"
 #include "../../core/utils.h"
 
 namespace tvm {
 namespace relax {
 using namespace tvm::contrib::msc;
 
-const Array<PrimExpr> GetShape(const Expr& var) {
-  const auto& shape_opt = 
Downcast<TensorStructInfo>(GetStructInfo(var))->GetShape();
-  ICHECK(shape_opt.defined()) << "Shape is not defined for " << var;
-  return shape_opt.value();
-}
-
-Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span, 
const String& suffix) {
-  const auto& name = SpanUtils::GetAttr(src_span, msc_attr::kName) + "_" + 
suffix;
-  expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name);
-  return builder->Emit(expr, name);
-}
-
-Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix, 
Expr op,
-             Array<Expr> args, Attrs attrs = Attrs()) {
-  const auto& call = Call(op, args, attrs);
-  return EmitCall(builder, call, src_span, suffix);
-}
+struct TensorRTTransConfig {
+  // Whether to cast linear to conv
+  bool linear_to_conv{true};
+  std::vector<size_t> version{0, 0, 0};
+
+  void Load(dmlc::JSONReader* reader) {
+    std::string key;
+    reader->BeginObject();
+    while (reader->NextObjectItem(&key)) {
+      if (key == "linear_to_conv") {
+        reader->Read(&linear_to_conv);
+      } else if (key == "version") {
+        reader->Read(&version);
+      } else {
+        LOG(FATAL) << "Do not support key " << key;
+      }
+    }
+  }
+};
 
-Expr MakeConstant(double value, const DataType& dtype, const String& name) {
-  const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
-  const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, name);
-  return Constant(data, NullOpt, span);
+const TensorRTTransConfig ParseConfig(const String& config_str) {
+  TensorRTTransConfig config;
+  if (config_str.size() > 0) {
+    std::istringstream is(config_str);
+    dmlc::JSONReader reader(&is);
+    reader.Read(&config);
+  }
+  return config;
 }
 
 using FRewriteTensorRT =
     runtime::TypedPackedFunc<Expr(BlockBuilder builder, const Var& var, const 
Call& src_call,
-                                  const Map<Expr, Call>& new_calls, const 
Array<Integer>& version)>;
+                                  const Map<Expr, Call>& new_calls, const 
String& config)>;
+
+const Array<PrimExpr> BroadcastShape(const Array<PrimExpr>& src_shape,
+                                     const Array<PrimExpr>& out_shape) {
+  size_t diff = out_shape.size() - src_shape.size();
+  Array<PrimExpr> leading_shape, tailing_shape;
+  for (size_t i = 0; i < diff; i++) {
+    leading_shape.push_back(Integer(1));
+  }
+  for (const auto& s : src_shape) {
+    tailing_shape.push_back(s);
+    leading_shape.push_back(s);
+  }
+  for (size_t i = 0; i < diff; i++) {
+    tailing_shape.push_back(Integer(1));
+  }
+  if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) {
+    return tailing_shape;
+  }
+  ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape))
+      << "Only support elemwise ops with leading or tailing expand";
+  return leading_shape;
+}
 
 Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                     const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                     const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& shape_a = GetShape(call->args[0]);
-  const auto& shape_b = GetShape(call->args[1]);
+  const auto& shape_a = ExprUtils::GetShape(call->args[0]);
+  const auto& shape_b = ExprUtils::GetShape(call->args[1]);
+  const auto& shape_out = ExprUtils::GetShape(var);
   static const Op& reshape_op = Op::Get("relax.reshape");
   if (shape_a.size() > shape_b.size()) {
-    Array<PrimExpr> exp_shape(shape_a.size(), Integer(1));
-    if (shape_b.size() == 1) {
-      exp_shape.Set(shape_a.size() - 1, shape_b[0]);
-    } else if (shape_b.size() == 0) {
-      LOG_DEBUG << "Expand scalar argument to " << exp_shape;
-    } else {
-      LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_b;
-    }
-    const auto& expand_b = MakeCall(builder, call->span, "expand_b", 
reshape_op,
-                                    {call->args[1], ShapeExpr(exp_shape)});
+    const auto& exp_shape = BroadcastShape(shape_b, shape_out);
+    const auto& expand_b =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"expand_b"), reshape_op,
+                               {call->args[1], ShapeExpr(exp_shape)});
     return Call(call->op, {call->args[0], expand_b}, call->attrs, 
call->sinfo_args, call->span);
-  }
-  if (shape_a.size() < shape_b.size()) {
-    Array<PrimExpr> exp_shape(shape_b.size(), Integer(1));
-    if (shape_a.size() == 1) {
-      exp_shape.Set(shape_b.size() - 1, shape_a[0]);
-    } else if (shape_a.size() == 0) {
-      LOG_DEBUG << "Expand scalar argument to " << exp_shape;
-    } else {
-      LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_a;
-    }
-    const auto& expand_a = MakeCall(builder, call->span, "expand_a", 
reshape_op,
-                                    {call->args[0], ShapeExpr(exp_shape)});
+  } else if (shape_a.size() < shape_b.size()) {
+    const auto& exp_shape = BroadcastShape(shape_a, shape_out);
+    const auto& expand_a =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"expand_a"), reshape_op,
+                               {call->args[0], ShapeExpr(exp_shape)});
     return Call(call->op, {expand_a, call->args[1]}, call->attrs, 
call->sinfo_args, call->span);
   }
   return call;
 }
 
 Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call,
-                const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
   if (new_calls.count(call->args[0]) &&
       new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) {
@@ -110,19 +128,20 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, 
const Call& src_call,
     if (conv2d->op != Op::Get("relax.nn.conv2d")) {
       return call;
     }
-    const auto& input_shape = GetShape(call->args[0]);
-    const auto& bias_shape = GetShape(call->args[1]);
+    const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+    const auto& bias_shape = ExprUtils::GetShape(call->args[1]);
     const auto* conv_attrs = conv2d->attrs.as<Conv2DAttrs>();
     if (conv_attrs->data_layout == "NCHW") {
       // expand bias reshape
       Array<PrimExpr> exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), 
bias_shape[2]};
       static const Op& reshape_op = Op::Get("relax.reshape");
-      const auto& exp_bias = MakeCall(builder, call->span, "exp_bias", 
reshape_op,
-                                      {call->args[1], 
ShapeExpr(exp_bias_shape)});
+      const auto& exp_bias =
+          RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_bias"), reshape_op,
+                                 {call->args[1], ShapeExpr(exp_bias_shape)});
       // redirect to conv2d
       static const Op& add_op = Op::Get("relax.add");
-      const auto& exp_add =
-          MakeCall(builder, call->span, "exp_add", add_op, {reshape->args[0], 
exp_bias});
+      const auto& exp_add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_add"),
+                                                   add_op, {reshape->args[0], 
exp_bias});
       // reduce output
       return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(), 
call->sinfo_args,
                   call->span);
@@ -130,48 +149,50 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, 
const Call& src_call,
       LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout;
     }
   }
-  return RewriteElemwise(builder, var, call, new_calls, version);
+  return RewriteElemwise(builder, var, call, new_calls, config);
 }
 
 Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& out_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(var))->dtype;
+  const auto& out_dtype = ExprUtils::GetDataType(var);
   const auto* src_attrs = src_call->attrs.as<ArgmaxArgminAttrs>();
-  Expr raw_var;
-  if (src_attrs->keepdims) {
-    raw_var = EmitCall(builder, call, call->span, "raw");
-  } else {
-    auto new_attrs = make_object<ArgmaxArgminAttrs>();
-    new_attrs->axis = src_attrs->axis;
-    new_attrs->keepdims = true;
-    raw_var =
-        MakeCall(builder, call->span, "keepdims", call->op, {call->args[0]}, 
Attrs(new_attrs));
+  ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64))
+      << "Unexpected out dtype " << out_dtype;
+  static const Op& topk_op = Op::Get("relax.topk");
+  auto topk_attrs = make_object<TopKAttrs>();
+  topk_attrs->k = 1;
+  if (src_attrs->axis.defined()) {
+    topk_attrs->axis = src_attrs->axis.value()->value;
   }
-  static const Op& astype_op = Op::Get("relax.astype");
-  auto cast_to_attrs = make_object<AstypeAttrs>();
-  cast_to_attrs->dtype = DataType::Int(32);
-  Expr res = MakeCall(builder, call->span, "cast_to", astype_op, {raw_var}, 
Attrs(cast_to_attrs));
-  // reshape back
-  if (!src_attrs->keepdims) {
-    const auto& output_shape = GetShape(var);
-    static const Op& reshape_op = Op::Get("relax.reshape");
-    res = MakeCall(builder, call->span, "reshape", reshape_op, {res, 
ShapeExpr(output_shape)});
+  topk_attrs->largest = call->op == Op::Get("relax.argmax");
+  topk_attrs->ret_type = "both";
+  topk_attrs->dtype = out_dtype;
+  // change to topk
+  const auto& topk = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "topk"), topk_op,
+                                            {call->args[0]}, 
Attrs(topk_attrs));
+  const auto& get_name = ExprUtils::GetSpanName(call, ".1");
+  const auto& get_item =
+      TupleGetItem(topk, 1, SpanUtils::CreateWithAttr(msc_attr::kName, 
get_name));
+  if (src_attrs->keepdims) {
+    return get_item;
   }
-  auto cast_from_attrs = make_object<AstypeAttrs>();
-  cast_from_attrs->dtype = out_dtype;
-  return Call(astype_op, {res}, Attrs(cast_from_attrs), call->sinfo_args, 
call->span);
+  const auto& get_item_var = builder->Emit(get_item, get_name);
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  const auto& output_shape = ExprUtils::GetShape(var);
+  return Call(reshape_op, {get_item_var, ShapeExpr(output_shape)}, Attrs(), 
call->sinfo_args,
+              call->span);
 }
 
 Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& in_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<AttentionAttrs>();
 
   // define dims
-  const auto& in_q_shape = GetShape(call->args[0]);
-  const auto& in_v_shape = GetShape(call->args[2]);
+  const auto& in_q_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& in_v_shape = ExprUtils::GetShape(call->args[2]);
   const auto& batch_size = in_q_shape[0];
   const auto& seq_len = in_q_shape[1];
   const auto& num_head = in_q_shape[2];
@@ -198,50 +219,53 @@ Expr RewriteAttention(BlockBuilder builder, const Var& 
var, const Call& src_call
   auto permute_attrs = make_object<PermuteDimsAttrs>();
   Array<Integer> axes{Integer(0), Integer(2), Integer(1), Integer(3)};
   permute_attrs->axes = axes;
-  const auto& q_trans = MakeCall(builder, call->span, "q_trans", 
permute_dims_op, {call->args[0]},
-                                 Attrs(permute_attrs));
-  const auto& k_trans = MakeCall(builder, call->span, "k_trans", 
permute_dims_op, {call->args[1]},
-                                 Attrs(permute_attrs));
-  const auto& v_trans = MakeCall(builder, call->span, "v_trans", 
permute_dims_op, {call->args[2]},
-                                 Attrs(permute_attrs));
+  const auto& q_trans =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), 
permute_dims_op,
+                             {call->args[0]}, Attrs(permute_attrs));
+  const auto& k_trans =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), 
permute_dims_op,
+                             {call->args[1]}, Attrs(permute_attrs));
+  const auto& v_trans =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), 
permute_dims_op,
+                             {call->args[2]}, Attrs(permute_attrs));
   Array<PrimExpr> q_shape({batch_size * num_head, seq_len, head_dim});
-  const auto& q_reshape =
-      MakeCall(builder, call->span, "q_reshape", reshape_op, {q_trans, 
ShapeExpr(q_shape)});
+  const auto& q_reshape = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "q_reshape"),
+                                                 reshape_op, {q_trans, 
ShapeExpr(q_shape)});
   Array<PrimExpr> k_shape({batch_size * num_head, seq_len_kv, head_dim});
-  const auto& k_reshape =
-      MakeCall(builder, call->span, "k_reshape", reshape_op, {k_trans, 
ShapeExpr(k_shape)});
+  const auto& k_reshape = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "k_reshape"),
+                                                 reshape_op, {k_trans, 
ShapeExpr(k_shape)});
   Array<PrimExpr> v_shape({batch_size * num_head, seq_len_kv, head_dim_v});
-  const auto& v_reshape =
-      MakeCall(builder, call->span, "v_reshape", reshape_op, {v_trans, 
ShapeExpr(v_shape)});
+  const auto& v_reshape = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "v_reshape"),
+                                                 reshape_op, {v_trans, 
ShapeExpr(v_shape)});
   auto reduce_permute_attrs = make_object<PermuteDimsAttrs>();
   Array<Integer> v_axes{Integer(0), Integer(2), Integer(1)};
   reduce_permute_attrs->axes = v_axes;
   // transpose for batch_matmul
-  const auto& k_reshape_trans = MakeCall(builder, call->span, 
"k_reshape_trans", permute_dims_op,
-                                         {k_reshape}, 
Attrs(reduce_permute_attrs));
+  const auto& k_reshape_trans =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"k_reshape_trans"),
+                             permute_dims_op, {k_reshape}, 
Attrs(reduce_permute_attrs));
 
   // calculate product
   auto matmul_attrs = make_object<MatmulAttrs>();
   matmul_attrs->out_dtype = in_dtype;
-  const auto& qk_prod = MakeCall(builder, call->span, "qk_prod", matmul_op,
-                                 {q_reshape, k_reshape_trans}, 
Attrs(matmul_attrs));
+  const auto& qk_prod =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), 
matmul_op,
+                             {q_reshape, k_reshape_trans}, 
Attrs(matmul_attrs));
   Expr p_scale;
   if (src_attrs->scale.defined()) {
-    const auto& scale = 
MakeConstant(static_cast<double>(src_attrs->scale.value()->value), in_dtype,
-                                     SpanUtils::GetAttr(call->span, 
msc_attr::kName) + "_scale");
-    Array<PrimExpr> exp_shape(3, Integer(1));
-    const auto& exp_scale =
-        MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, 
ShapeExpr(exp_shape)});
-    p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod, 
exp_scale});
+    double value = static_cast<double>(src_attrs->scale.value()->value);
+    const auto& scale = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "scale"),
+                                                   value, in_dtype, 3);
+    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"p_scale"), multiply_op,
+                                     {qk_prod, scale});
   } else {
-    const auto& scale =
-        MakeConstant(static_cast<double>(Downcast<Integer>(head_dim)->value), 
in_dtype,
-                     SpanUtils::GetAttr(call->span, msc_attr::kName) + 
"_scale");
-    Array<PrimExpr> exp_shape(3, Integer(1));
-    const auto& exp_scale =
-        MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, 
ShapeExpr(exp_shape)});
-    const auto& sqrt_scale = MakeCall(builder, call->span, "sqrt_scale", 
sqrt_op, {exp_scale});
-    p_scale = MakeCall(builder, call->span, "p_scale", divide_op, {qk_prod, 
sqrt_scale});
+    double value = static_cast<double>(Downcast<Integer>(head_dim)->value);
+    const auto& scale = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "scale"),
+                                                   value, in_dtype, 3);
+    const auto& sqrt_scale = RewriteUtils::MakeCall(
+        builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale});
+    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"p_scale"), divide_op,
+                                     {qk_prod, sqrt_scale});
   }
 
   // bias
@@ -249,12 +273,12 @@ Expr RewriteAttention(BlockBuilder builder, const Var& 
var, const Call& src_call
   if (call->args.size() == 4) {
     Array<PrimExpr> exp_shape{batch_size, num_head, seq_len, seq_len_kv};
     Array<PrimExpr> reduce_shape{batch_size * num_head, seq_len, seq_len_kv};
-    const auto& prod_exp =
-        MakeCall(builder, call->span, "prod_exp", reshape_op, {prod, 
ShapeExpr(exp_shape)});
-    const auto& prod_add =
-        MakeCall(builder, call->span, "prod_add", add_op, {prod_exp, 
call->args[3]});
-    prod = MakeCall(builder, call->span, "prod_reduce", reshape_op,
-                    {prod_add, ShapeExpr(reduce_shape)});
+    const auto& prod_exp = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "prod_exp"),
+                                                  reshape_op, {prod, 
ShapeExpr(exp_shape)});
+    const auto& prod_add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "prod_add"),
+                                                  add_op, {prod_exp, 
call->args[3]});
+    prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"prod_reduce"), reshape_op,
+                                  {prod_add, ShapeExpr(reduce_shape)});
   }
 
   // causal_mask
@@ -262,7 +286,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, 
const Call& src_call
   if (!src_attrs->causal_mask.defined()) {
     auto softmax_attrs = make_object<SoftmaxAttrs>();
     softmax_attrs->axis = 2;
-    s_value = MakeCall(builder, call->span, "act", softmax_op, {prod}, 
Attrs(softmax_attrs));
+    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"act"), softmax_op,
+                                     {prod}, Attrs(softmax_attrs));
   } else {
     const auto& causal_mask = src_attrs->causal_mask.value();
     PrimValue tril_k;
@@ -273,41 +298,47 @@ Expr RewriteAttention(BlockBuilder builder, const Var& 
var, const Call& src_call
     } else {
       LOG_FATAL << "Unexpected causal_mask " << causal_mask;
     }
-    const auto& p_masked = MakeCall(builder, call->span, "p_masked", tril_op, 
{prod, tril_k});
+    const auto& p_masked = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "p_masked"),
+                                                  tril_op, {prod, tril_k});
     auto reduce_attrs = make_object<StatisticalAttrs>();
     Array<Integer> axis{Integer(2)};
     reduce_attrs->axis = axis;
     reduce_attrs->keepdims = true;
-    const auto& p_max = MakeCall(builder, call->span, "p_max", max_op, {prod}, 
Attrs(reduce_attrs));
-    const auto& p_diff = MakeCall(builder, call->span, "p_diff", subtract_op, 
{p_masked, p_max});
-    const auto& p_exp = MakeCall(builder, call->span, "p_exp", exp_op, 
{p_diff});
-    const auto& p_masked_exp =
-        MakeCall(builder, call->span, "p_masked_exp", tril_op, {p_exp, 
tril_k});
+    const auto& p_max = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "p_max"),
+                                               max_op, {prod}, 
Attrs(reduce_attrs));
+    const auto& p_diff = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "p_diff"),
+                                                subtract_op, {p_masked, 
p_max});
+    const auto& p_exp =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), 
exp_op, {p_diff});
+    const auto& p_masked_exp = RewriteUtils::MakeCall(
+        builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, 
{p_exp, tril_k});
     const auto& p_masked_sum =
-        MakeCall(builder, call->span, "p_masked_sum", sum_op, {p_masked_exp}, 
Attrs(reduce_attrs));
-    s_value = MakeCall(builder, call->span, "act", divide_op, {p_masked_exp, 
p_masked_sum});
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"p_masked_sum"), sum_op,
+                               {p_masked_exp}, Attrs(reduce_attrs));
+    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"act"), divide_op,
+                                     {p_masked_exp, p_masked_sum});
   }
 
   // final calculation
-  const auto& o_prod =
-      MakeCall(builder, call->span, "o_prod", matmul_op, {s_value, v_reshape}, 
Attrs(matmul_attrs));
+  const auto& o_prod = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "o_prod"),
+                                              matmul_op, {s_value, v_reshape}, 
Attrs(matmul_attrs));
   Array<PrimExpr> o_shape{batch_size, num_head, seq_len, head_dim_v};
   return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), 
call->sinfo_args, call->span);
 }
 
 Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& in_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<BatchNormAttrs>();
   // define expand shape
   Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
   exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]);
 
   // create eps constant
-  const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
-                                 SpanUtils::GetAttr(call->span, 
msc_attr::kName) + "_eps");
+  const auto& eps = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "eps"),
+                                               src_attrs->epsilon, in_dtype);
 
   // create ops
   static const Op& add_op = Op::Get("relax.add");
@@ -318,36 +349,43 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& 
var, const Call& src_call
   static const Op& subtract_op = Op::Get("relax.subtract");
 
   // scale factor: gamma/sqrt(var + epsilon)
-  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, 
{call->args[4], eps});
-  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
-  const auto& scale_factor =
-      MakeCall(builder, call->span, "scale_factor", divide_op, {call->args[1], 
sqrt});
+  const auto& eps_add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "eps_add"),
+                                               add_op, {call->args[4], eps});
+  const auto& sqrt =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), 
sqrt_op, {eps_add});
+  const auto& scale_factor = RewriteUtils::MakeCall(
+      builder, ExprUtils::GetSpanName(call, "scale_factor"), divide_op, 
{call->args[1], sqrt});
   Expr res = call->args[0];
   // scale
   if (src_attrs->scale) {
-    const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", 
reshape_op,
-                                     {scale_factor, ShapeExpr(exp_shape)});
-    res = MakeCall(builder, call->span, "scale", multiply_op, {res, 
exp_scale});
+    const auto& exp_scale =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_scale"), reshape_op,
+                               {scale_factor, ShapeExpr(exp_shape)});
+    res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"scale"), multiply_op,
+                                 {res, exp_scale});
   }
   // offset
   if (src_attrs->center) {
     // offset factor: beta-mean*scale_factor
-    const auto& average =
-        MakeCall(builder, call->span, "average", multiply_op, {call->args[3], 
scale_factor});
+    const auto& average = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "average"),
+                                                 multiply_op, {call->args[3], 
scale_factor});
     const auto& offset_factor =
-        MakeCall(builder, call->span, "offset_factor", subtract_op, 
{call->args[2], average});
-    const auto& exp_offset = MakeCall(builder, call->span, "exp_offset", 
reshape_op,
-                                      {offset_factor, ShapeExpr(exp_shape)});
-    res = MakeCall(builder, call->span, "offset", add_op, {res, exp_offset});
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"offset_factor"), subtract_op,
+                               {call->args[2], average});
+    const auto& exp_offset =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_offset"), reshape_op,
+                               {offset_factor, ShapeExpr(exp_shape)});
+    res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"offset"), add_op,
+                                 {res, exp_offset});
   }
   return Tuple(Array<Expr>{res}, call->span);
 }
 
 Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                        const Map<Expr, Call>& new_calls, const 
Array<Integer>& version) {
+                        const Map<Expr, Call>& new_calls, const String& 
config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& output_shape = GetShape(var);
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& output_shape = ExprUtils::GetShape(var);
   Expr concat_input = call->args[0];
   static const Op& concat_op = Op::Get("relax.concat");
   for (size_t i = 0; i < input_shape.size(); i++) {
@@ -357,30 +395,33 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& 
var, const Call& src_ca
       Array<Expr> concat_inputs(out_dim / in_dim, concat_input);
       auto concat_attrs = make_object<ConcatAttrs>();
       concat_attrs->axis = Integer(i);
-      concat_input = MakeCall(builder, call->span, "concat_" + 
std::to_string(i), concat_op,
-                              {Tuple(concat_inputs)}, Attrs(concat_attrs));
+      concat_input = RewriteUtils::MakeCall(
+          builder, ExprUtils::GetSpanName(call, "concat_" + 
std::to_string(i)), concat_op,
+          {Tuple(concat_inputs)}, Attrs(concat_attrs));
     }
   }
   return concat_input;
 }
 
 Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
-                   const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                   const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
   const auto* src_attrs = src_call->attrs.as<Conv1DAttrs>();
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& weight_shape = GetShape(call->args[1]);
-  const auto& output_shape = GetShape(var);
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& weight_shape = ExprUtils::GetShape(call->args[1]);
+  const auto& output_shape = ExprUtils::GetShape(var);
   if (src_attrs->data_layout == "NCW") {
     Array<Expr> new_args;
     // expand inputs
     Array<PrimExpr> exp_input_shape{input_shape[0], input_shape[1], 
Integer(1), input_shape[2]};
     Array<PrimExpr> exp_weight_shape{weight_shape[0], weight_shape[1], 
Integer(1), weight_shape[2]};
     static const Op& reshape_op = Op::Get("relax.reshape");
-    new_args.push_back(MakeCall(builder, call->span, "exp_input", reshape_op,
-                                {call->args[0], ShapeExpr(exp_input_shape)}));
-    new_args.push_back(MakeCall(builder, call->span, "exp_weight", reshape_op,
-                                {call->args[1], ShapeExpr(exp_weight_shape)}));
+    new_args.push_back(RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_input"),
+                                              reshape_op,
+                                              {call->args[0], 
ShapeExpr(exp_input_shape)}));
+    new_args.push_back(RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_weight"),
+                                              reshape_op,
+                                              {call->args[1], 
ShapeExpr(exp_weight_shape)}));
     // change to conv2d
     static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
     auto conv_attrs = make_object<Conv2DAttrs>();
@@ -393,8 +434,8 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, 
const Call& src_call,
     conv_attrs->kernel_layout = "OIHW";
     conv_attrs->out_layout = "NCHW";
     conv_attrs->out_dtype = src_attrs->out_dtype;
-    const auto& conv2d =
-        MakeCall(builder, call->span, "exp", conv2d_op, new_args, 
Attrs(conv_attrs));
+    const auto& conv2d = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp"),
+                                                conv2d_op, new_args, 
Attrs(conv_attrs));
     // reduce output
     return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(), 
call->sinfo_args,
                 call->span);
@@ -404,11 +445,80 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, 
const Call& src_call,
   return call;
 }
 
+Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call,
+                 const Map<Expr, Call>& new_calls, const String& config) {
+  // GELU (erf form) as elementwise ops: 0.5 * x * (1 + erf(sqrt(0.5) * x)); each node gets a unique span name
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
+  size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
+  // create ops
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& erf_op = Op::Get("relax.erf");
+
+  const auto& factor = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "factor"),
+                                                  std::sqrt(0.5), in_dtype, 
in_dim);
+  const auto& mul = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mul"),
+                                           multiply_op, {factor, 
call->args[0]});
+  const auto& erf =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "erf"), 
erf_op, {mul});
+  const auto& one =
+      RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 
1, in_dtype, in_dim);
+  const auto& add =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), 
add_op, {one, erf});
+  const auto& mul2 = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mul2"),
+                                            multiply_op, {call->args[0], add});
+  const auto& half = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "half"), 0.5,
+                                                in_dtype, in_dim);
+  return Call(multiply_op, {half, mul2}, Attrs(), call->sinfo_args, 
call->span);
+}
+
+Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& 
src_call,
+                     const Map<Expr, Call>& new_calls, const String& config) {
+  // GELU (tanh approx.): 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x))); unique span name per node
+  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
+  size_t in_dim = ExprUtils::GetShape(call->args[0]).size();
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
+
+  // create ops
+  static const Op& add_op = Op::Get("relax.add");
+  static const Op& multiply_op = Op::Get("relax.multiply");
+  static const Op& pow_op = Op::Get("relax.power");
+  static const Op& tanh_op = Op::Get("relax.tanh");
+
+  const auto& pow_factor = RewriteUtils::MakeConstant(
+      builder, ExprUtils::GetSpanName(call, "pow_factor"), 3, in_dtype, 
in_dim);
+  const auto& mul_factor = RewriteUtils::MakeConstant(
+      builder, ExprUtils::GetSpanName(call, "mul_factor"), 0.044715, in_dtype, 
in_dim);
+  const auto& pi_factor = RewriteUtils::MakeConstant(
+      builder, ExprUtils::GetSpanName(call, "pi_factor"), std::sqrt(2 / M_PI), 
in_dtype, in_dim);
+
+  const auto& pow = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "pow"), pow_op,
+                                           {call->args[0], pow_factor});
+  const auto& mul = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mul"),
+                                           multiply_op, {mul_factor, pow});
+  const auto& add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "add"), add_op,
+                                           {mul, call->args[0]});
+  const auto& mul2 = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mul2"),
+                                            multiply_op, {pi_factor, add});
+  const auto& tanh =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "tanh"), 
tanh_op, {mul2});
+  const auto& one =
+      RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 
1, in_dtype, in_dim);
+  const auto& add2 =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add2"), 
add_op, {one, tanh});
+  const auto& mul3 = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mul3"),
+                                            multiply_op, {call->args[0], 
add2});
+  const auto& half = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "half"), 0.5,
+                                                in_dtype, in_dim);
+  return Call(multiply_op, {half, mul3}, Attrs(), call->sinfo_args, 
call->span);
+}
+
 Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& in_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<GroupNormAttrs>();
   Array<PrimExpr> group_shape = input_shape;
   Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
@@ -420,8 +530,8 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, 
const Call& src_call
   exp_shape.Set(axis, Integer(src_attrs->num_groups));
 
   // create eps constant
-  const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
-                                 SpanUtils::GetAttr(call->span, 
msc_attr::kName) + "_eps");
+  const auto& eps = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "eps"),
+                                               src_attrs->epsilon, in_dtype);
 
   // create ops
   static const Op& add_op = Op::Get("relax.add");
@@ -434,53 +544,63 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& 
var, const Call& src_call
   static const Op& subtract_op = Op::Get("relax.subtract");
 
   // reshape input
-  const auto& reshape_in = MakeCall(builder, call->span, "reshape_in", 
reshape_op,
-                                    {call->args[0], ShapeExpr(group_shape)});
+  const auto& reshape_in =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"reshape_in"), reshape_op,
+                             {call->args[0], ShapeExpr(group_shape)});
 
   // mean(input)
   auto mean_attrs = make_object<StatisticalAttrs>();
   mean_attrs->axis = src_attrs->axes;
   mean_attrs->keepdims = true;
-  const auto& mean =
-      MakeCall(builder, call->span, "mean", mean_op, {reshape_in}, 
Attrs(mean_attrs));
+  const auto& mean = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mean"), mean_op,
+                                            {reshape_in}, Attrs(mean_attrs));
 
   // variance: mean((input-mean)*(input-mean))
-  const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, 
{reshape_in, mean});
-  const auto& square = MakeCall(builder, call->span, "square", square_op, 
{diff});
-  const auto& variance =
-      MakeCall(builder, call->span, "variance", mean_op, {square}, 
Attrs(mean_attrs));
+  const auto& diff = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "diff"),
+                                            subtract_op, {reshape_in, mean});
+  const auto& square =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), 
square_op, {diff});
+  const auto& variance = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "variance"),
+                                                mean_op, {square}, 
Attrs(mean_attrs));
 
   // sqrt(var + epsilon)
   Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
-  const auto& exp_eps =
-      MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, 
ShapeExpr(exp_eps_shape)});
-  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, 
{variance, exp_eps});
-  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+  const auto& exp_eps = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_eps"),
+                                               reshape_op, {eps, 
ShapeExpr(exp_eps_shape)});
+  const auto& eps_add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "eps_add"),
+                                               add_op, {variance, exp_eps});
+  const auto& sqrt =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), 
sqrt_op, {eps_add});
 
   // diff/sqrt
-  Expr res = MakeCall(builder, call->span, "divide", divide_op, {diff, sqrt});
+  Expr res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"divide"), divide_op,
+                                    {diff, sqrt});
 
   // scale
   if (src_attrs->scale) {
-    const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", 
reshape_op,
-                                     {call->args[1], ShapeExpr(exp_shape)});
-    res = MakeCall(builder, call->span, "scale", multiply_op, {res, 
exp_gamma});
+    const auto& exp_gamma =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_gamma"), reshape_op,
+                               {call->args[1], ShapeExpr(exp_shape)});
+    res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"scale"), multiply_op,
+                                 {res, exp_gamma});
   }
   // offset
   if (src_attrs->center) {
-    const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", 
reshape_op,
-                                    {call->args[2], ShapeExpr(exp_shape)});
-    res = MakeCall(builder, call->span, "offset", add_op, {res, exp_beta});
+    const auto& exp_beta =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_beta"), reshape_op,
+                               {call->args[2], ShapeExpr(exp_shape)});
+    res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"offset"), add_op,
+                                 {res, exp_beta});
   }
   // reshape output
   return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(), 
call->sinfo_args, call->span);
 }
 
 Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& in_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<LayerNormAttrs>();
   Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
   for (const auto& a : src_attrs->axes) {
@@ -488,8 +608,8 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, 
const Call& src_call
     exp_shape.Set(index, input_shape[index]);
   }
   // create eps constant
-  const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype,
-                                 SpanUtils::GetAttr(call->span, 
msc_attr::kName) + "_eps");
+  const auto& eps = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "eps"),
+                                               src_attrs->epsilon, in_dtype);
 
   // create ops
   static const Op& add_op = Op::Get("relax.add");
@@ -505,30 +625,36 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& 
var, const Call& src_call
   auto mean_attrs = make_object<StatisticalAttrs>();
   mean_attrs->axis = src_attrs->axes;
   mean_attrs->keepdims = true;
-  const auto& mean =
-      MakeCall(builder, call->span, "mean", mean_op, {call->args[0]}, 
Attrs(mean_attrs));
+  const auto& mean = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "mean"), mean_op,
+                                            {call->args[0]}, 
Attrs(mean_attrs));
 
   // variance: mean((input-mean)*(input-mean))
-  const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, 
{call->args[0], mean});
-  const auto& square = MakeCall(builder, call->span, "square", square_op, 
{diff});
-  const auto& variance =
-      MakeCall(builder, call->span, "variance", mean_op, {square}, 
Attrs(mean_attrs));
+  const auto& diff = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "diff"),
+                                            subtract_op, {call->args[0], 
mean});
+  const auto& square =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), 
square_op, {diff});
+  const auto& variance = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "variance"),
+                                                mean_op, {square}, 
Attrs(mean_attrs));
 
   // sqrt(var + epsilon)
   Array<PrimExpr> exp_eps_shape(input_shape.size(), Integer(1));
-  const auto& exp_eps =
-      MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, 
ShapeExpr(exp_eps_shape)});
-  const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, 
{variance, exp_eps});
-  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add});
+  const auto& exp_eps = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_eps"),
+                                               reshape_op, {eps, 
ShapeExpr(exp_eps_shape)});
+  const auto& eps_add = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "eps_add"),
+                                               add_op, {variance, exp_eps});
+  const auto& sqrt =
+      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), 
sqrt_op, {eps_add});
 
   // diff/sqrt
   Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args, 
call->span);
 
   // scale
   if (src_attrs->scale) {
-    const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", 
reshape_op,
-                                     {call->args[1], ShapeExpr(exp_shape)});
-    const auto& res_var = EmitCall(builder, res, call->span, "pre_scale");
+    const auto& exp_gamma =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_gamma"), reshape_op,
+                               {call->args[1], ShapeExpr(exp_shape)});
+    const auto& res_var =
+        RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, 
"pre_scale"), res);
     if (src_attrs->center) {
       res = Call(multiply_op, {res_var, exp_gamma});
     } else {
@@ -537,87 +663,126 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& 
var, const Call& src_call
   }
   // offset
   if (src_attrs->center) {
-    const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", 
reshape_op,
-                                    {call->args[2], ShapeExpr(exp_shape)});
-    const auto& res_var = EmitCall(builder, res, call->span, "pre_offset");
+    const auto& exp_beta =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_beta"), reshape_op,
+                               {call->args[2], ShapeExpr(exp_shape)});
+    const auto& res_var =
+        RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, 
"pre_offset"), res);
     res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args, 
call->span);
   }
   return res;
 }
 
 Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
-                   const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                   const Map<Expr, Call>& new_calls, const String& config) {
+  const auto& trt_config = ParseConfig(config);
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& shape_a = GetShape(call->args[0]);
-  const auto& shape_b = GetShape(call->args[1]);
+  const auto& shape_a = ExprUtils::GetShape(call->args[0]);
+  const auto& shape_b = ExprUtils::GetShape(call->args[1]);
   static const Op& reshape_op = Op::Get("relax.reshape");
+  if (call->args[1]->IsInstance<ConstantNode>() && shape_b.size() == 2 &&
+      trt_config.linear_to_conv) {
+    const auto& out_shape = ExprUtils::GetShape(var);
+    PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1);
+    Array<PrimExpr> exp_shape{accumulate, shape_a[shape_a.size() - 1], 
Integer(1), Integer(1)};
+    const auto& exp_in = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "exp_in"),
+                                                reshape_op, {call->args[0], 
ShapeExpr(exp_shape)});
+    // transpose and expand weight to OIHW
+    static const Op& permute_dims_op = Op::Get("relax.permute_dims");
+    auto permute_attrs = make_object<PermuteDimsAttrs>();
+    Array<Integer> axes{Integer(1), Integer(0)};
+    permute_attrs->axes = axes;
+    const auto& trans_weight =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"trans_weight"),
+                               permute_dims_op, {call->args[1]}, 
Attrs(permute_attrs));
+    Array<PrimExpr> weight_shape{shape_b[1], shape_b[0], Integer(1), 
Integer(1)};
+    const auto& exp_weight =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"exp_weight"), reshape_op,
+                               {trans_weight, ShapeExpr(weight_shape)});
+    // to conv2d
+    static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
+    auto conv_attrs = make_object<Conv2DAttrs>();
+    conv_attrs->strides = Array<IntImm>{Integer(1), Integer(1)};
+    conv_attrs->padding = Array<IntImm>{Integer(0), Integer(0), Integer(0), 
Integer(0)};
+    conv_attrs->dilation = Array<IntImm>{Integer(1), Integer(1)};
+    conv_attrs->groups = 1;
+    conv_attrs->data_layout = "NCHW";
+    conv_attrs->kernel_layout = "OIHW";
+    conv_attrs->out_layout = "NCHW";
+    conv_attrs->out_dtype = ExprUtils::GetDataType(var);
+    const auto& conv2d = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "conv2d"),
+                                                conv2d_op, {exp_in, 
exp_weight}, Attrs(conv_attrs));
+    return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), 
call->sinfo_args, call->span);
+  }
   if (shape_a.size() > shape_b.size()) {
     Array<PrimExpr> exp_shape(shape_a.size(), Integer(1));
-    for (size_t i = shape_b.size(); i < shape_a.size(); i++) {
-      exp_shape.Set(i, shape_b[i - shape_b.size()]);
+    size_t diff = shape_a.size() - shape_b.size();
+    for (size_t i = diff; i < shape_a.size(); i++) {
+      exp_shape.Set(i, shape_b[i - diff]);
     }
-    const auto& expand_b = MakeCall(builder, call->span, "expand_b", 
reshape_op,
-                                    {call->args[1], ShapeExpr(exp_shape)});
+    const auto& expand_b =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"expand_b"), reshape_op,
+                               {call->args[1], ShapeExpr(exp_shape)});
     return Call(call->op, {call->args[0], expand_b}, call->attrs, 
call->sinfo_args, call->span);
   }
   if (shape_a.size() < shape_b.size()) {
     Array<PrimExpr> exp_shape(shape_b.size(), Integer(1));
-    for (size_t i = shape_a.size(); i < shape_b.size(); i++) {
-      exp_shape.Set(i, shape_a[i - shape_a.size()]);
+    size_t diff = shape_b.size() - shape_a.size();
+    for (size_t i = diff; i < shape_b.size(); i++) {
+      exp_shape.Set(i, shape_a[i - diff]);
     }
-    const auto& expand_a = MakeCall(builder, call->span, "expand_a", 
reshape_op,
-                                    {call->args[0], ShapeExpr(exp_shape)});
+    const auto& expand_a =
+        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, 
"expand_a"), reshape_op,
+                               {call->args[0], ShapeExpr(exp_shape)});
     return Call(call->op, {expand_a, call->args[1]}, call->attrs, 
call->sinfo_args, call->span);
   }
   return call;
 }
 
 Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call,
-                  const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                  const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
-  const auto& in_dtype = 
Downcast<TensorStructInfo>(GetStructInfo(call->args[0]))->dtype;
-  Array<PrimExpr> exp_shape(input_shape.size(), Integer(1));
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
+  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
   // create 1 constant
-  const auto& one =
-      MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, 
msc_attr::kName) + "_one");
+  const auto& one = RewriteUtils::MakeConstant(builder, 
ExprUtils::GetSpanName(call, "eps"), 1,
+                                               in_dtype, input_shape.size());
 
   // create ops
-  static const Op& reshape_op = Op::Get("relax.reshape");
   static const Op& divide_op = Op::Get("relax.divide");
   static const Op& sqrt_op = Op::Get("relax.sqrt");
 
   // expand and divide
-  const auto& exp_one =
-      MakeCall(builder, call->span, "exp_one", reshape_op, {one, 
ShapeExpr(exp_shape)});
-  const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, 
{call->args[0]});
-  return Call(divide_op, {exp_one, sqrt}, Attrs(), call->sinfo_args, 
call->span);
+  const auto& sqrt = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "sqrt"), sqrt_op,
+                                            {call->args[0]});
+  return Call(divide_op, {one, sqrt}, Attrs(), call->sinfo_args, call->span);
 }
 
 Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call,
-                 const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                 const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
   // create ops
   static const Op& multiply_op = Op::Get("relax.multiply");
   static const Op& sigmoid_op = Op::Get("relax.sigmoid");
   // silu=input*sigmoid(input)
-  const auto& sigmoid = MakeCall(builder, call->span, "sigmoid", sigmoid_op, 
{call->args[0]});
+  const auto& sigmoid = RewriteUtils::MakeCall(builder, 
ExprUtils::GetSpanName(call, "sigmoid"),
+                                               sigmoid_op, {call->args[0]});
   return Call(multiply_op, {call->args[0], sigmoid}, Attrs(), 
call->sinfo_args, call->span);
 }
 
 Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& 
src_call,
-                      const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                      const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& output_shape = GetShape(var);
+  const auto& output_shape = ExprUtils::GetShape(var);
   static const Op& reshape_op = Op::Get("relax.reshape");
   return Call(reshape_op, {call->args[0], ShapeExpr(output_shape)}, Attrs(), 
call->sinfo_args,
               call->span);
 }
 
 Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
-                  const Map<Expr, Call>& new_calls, const Array<Integer>& 
version) {
+                  const Map<Expr, Call>& new_calls, const String& config) {
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : 
src_call;
-  const auto& input_shape = GetShape(call->args[0]);
+  const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<SplitAttrs>();
   size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size());
   std::vector<int64_t> split_begins, split_ends;
@@ -646,9 +811,16 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, 
const Call& src_call,
   // create strided_slices
   Array<Expr> outputs;
   for (size_t i = 0; i < split_begins.size(); i++) {
-    auto slice = strided_slice(call->args[0], 
Tuple(Array<Expr>{PrimValue(Integer(axis))}),
-                               
Tuple(Array<Expr>{PrimValue(Integer(split_begins[i]))}),
-                               
Tuple(Array<Expr>{PrimValue(Integer(split_ends[i]))}));
+    static const Op& strided_slice_op = Op::Get("relax.strided_slice");
+    const auto& axes = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64), 
axis))});
+    const auto& begin = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64), 
split_begins[i]))});
+    const auto& end = Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64), 
split_ends[i]))});
+    const auto& strides = 
Tuple(Array<Expr>{PrimValue(IntImm(DataType::Int(64), 1))});
+    auto attrs = make_object<StridedSliceAttrs>();
+    attrs->assume_inbound = true;
+    const auto& slice = RewriteUtils::MakeCall(
+        builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), 
strided_slice_op,
+        {call->args[0], axes, begin, end, strides}, Attrs(attrs));
     outputs.push_back(slice);
   }
   return Tuple(outputs, call->span);
@@ -664,6 +836,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm")
 
TVM_REGISTER_OP("relax.nn.conv1d").set_attr<FRewriteTensorRT>("FRewriteTensorRT",
 RewriteConv1d);
 TVM_REGISTER_OP("relax.nn.group_norm")
     .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteGroupNorm);
+TVM_REGISTER_OP("relax.nn.gelu").set_attr<FRewriteTensorRT>("FRewriteTensorRT",
 RewriteGelu);
+TVM_REGISTER_OP("relax.nn.gelu_tanh")
+    .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteGeluTanh);
 TVM_REGISTER_OP("relax.nn.layer_norm")
     .set_attr<FRewriteTensorRT>("FRewriteTensorRT", RewriteLayerNorm);
 
TVM_REGISTER_OP("relax.nn.silu").set_attr<FRewriteTensorRT>("FRewriteTensorRT", 
RewriteSilu);
@@ -695,9 +870,9 @@ 
TVM_REGISTER_OP("relax.split").set_attr<FRewriteTensorRT>("FRewriteTensorRT", Re
 
 class TensorRTTransformer : public ExprMutator {
  public:
-  explicit TensorRTTransformer(IRModule ctx_module, const Array<Integer>& 
version)
+  explicit TensorRTTransformer(IRModule ctx_module, const String& config)
       : ExprMutator(ctx_module) {
-    version_ = version;
+    config_ = config;
   }
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) 
final {
@@ -707,7 +882,7 @@ class TensorRTTransformer : public ExprMutator {
       if (rewrite_map.count(op)) {
         const auto& call = GetRef<Call>(call_node);
         FRewriteTensorRT f = rewrite_map[op];
-        const auto& new_call = f(builder_, binding->var, call, new_calls_, 
version_);
+        const auto& new_call = f(builder_, binding->var, call, new_calls_, 
config_);
         if (new_call != call) {
           ReEmitBinding(binding, builder_->Normalize(new_call));
           new_calls_.Set(binding->var, call);
@@ -721,20 +896,19 @@ class TensorRTTransformer : public ExprMutator {
 
  private:
   Map<Expr, Call> new_calls_;
-  Array<Integer> version_;
+  String config_;
 };
 
-Function TransformTensorRT(const Function& func, const IRModule& module,
-                           const Array<Integer>& version) {
-  return Downcast<Function>(TensorRTTransformer(module, 
version).VisitExpr(func));
+Function TransformTensorRT(const Function& func, const IRModule& module, const 
String& config) {
+  return Downcast<Function>(TensorRTTransformer(module, 
config).VisitExpr(func));
 }
 
 namespace transform {
 
-Pass TransformTensorRT(const Array<Integer>& version) {
+Pass TransformTensorRT(const String& config) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return relax::TransformTensorRT(f, m, version);
+        return relax::TransformTensorRT(f, m, config);
       };
   return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {});
 }
diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py 
b/tests/python/contrib/test_msc/test_translate_tensorrt.py
index 74c25ceacf..7c8c283099 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorrt.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py
@@ -87,7 +87,7 @@ def check_names(mod):
         NameChecker().check(func)
 
 
-def verify_model(torch_model, input_info, allow_incomplete=False):
+def verify_model(torch_model, input_info, **trans_config):
     """Build model and verify results"""
 
     graph_model = fx.symbolic_trace(torch_model)
@@ -100,9 +100,7 @@ def verify_model(torch_model, input_info, 
allow_incomplete=False):
         golden = [golden]
     golden = [g.detach().cpu().numpy() for g in golden]
     # partition module for tensorrt
-    mod, graphs, weights = translate.partition_for_tensorrt(
-        mod, trans_config={"allow_incomplete": allow_incomplete}
-    )
+    mod, graphs, weights = translate.partition_for_tensorrt(mod, 
trans_config=trans_config)
     check_names(mod)
     output_folder = msc_utils.msc_dir()
     # tranalte to tensorrt
@@ -191,6 +189,8 @@ def test_linear():
     input_info = [([1, 3, 10, 10], "float32")]
     verify_model(Dense1(), input_info)
     verify_model(Dense2(), input_info)
+    verify_model(Dense1(), input_info, linear_to_conv=True)
+    verify_model(Dense2(), input_info, linear_to_conv=True)
     verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
 
 
@@ -368,10 +368,10 @@ def test_embedding():
             self.embedding = torch.nn.Embedding(10, 3)
 
         def forward(self, data):
-            return self.embedding(data)
+            return self.embedding(data.to(torch.int64))
 
-    verify_model(Embedding(), [([4], "int64")], allow_incomplete=True)
-    verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True)
+    verify_model(Embedding(), [([4], "int32")])
+    verify_model(Embedding(), [([4, 5], "int32")])
 
 
 @requires_tensorrt
@@ -801,14 +801,14 @@ def test_argmax():
 
     class Argmax1(Module):
         def forward(self, data):
-            return torch.argmax(data, dim=-1)
+            return torch.argmax(data, dim=-1).to(torch.int32)
 
     class Argmax2(Module):
         def forward(self, data):
-            return torch.argmax(data, dim=-1, keepdim=True)
+            return torch.argmax(data, dim=-1, keepdim=True).to(torch.int32)
 
-    verify_model(Argmax1(), [([256, 256], "float32")], allow_incomplete=True)
-    verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True)
+    verify_model(Argmax1(), [([256, 256], "float32")])
+    verify_model(Argmax2(), [([256, 256], "float32")])
 
 
 @requires_tensorrt
@@ -817,14 +817,14 @@ def test_argmin():
 
     class Argmin1(Module):
         def forward(self, data):
-            return torch.argmin(data, dim=-1)
+            return torch.argmin(data, dim=-1).to(torch.int32)
 
     class Argmin2(Module):
         def forward(self, data):
-            return torch.argmin(data, dim=-1, keepdim=True)
+            return torch.argmin(data, dim=-1, keepdim=True).to(torch.int32)
 
-    verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True)
-    verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True)
+    verify_model(Argmin1(), [([256, 256], "float32")])
+    verify_model(Argmin2(), [([256, 256], "float32")])
 
 
 @requires_tensorrt
@@ -876,5 +876,22 @@ def test_max():
     verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
 
 
+@requires_tensorrt
+def test_gelu():
+    """test tensorrt translator for gelu"""
+
+    class Gelu1(Module):
+        def forward(self, data):
+            return torch.nn.functional.gelu(data)
+
+    class Gelu2(Module):
+        def forward(self, data):
+            return torch.nn.functional.gelu(data, approximate="tanh")
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Gelu1(), input_info)
+    verify_model(Gelu2(), input_info)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to