This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2fd4e11194 [PYTHON] Fix PEP 563 compat and remove args_converter (#18847)
2fd4e11194 is described below

commit 2fd4e111949ca60f6a76e09d185847b07a8f8db1
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 28 10:41:10 2026 -0500

    [PYTHON] Fix PEP 563 compat and remove args_converter (#18847)
---
 python/tvm/relax/__init__.py                       |   2 -
 python/tvm/relax/op/base.py                        |  25 ++-
 python/tvm/relax/op/builtin/builtin.py             |   5 +-
 python/tvm/relax/op/distributed/distributed.py     |   7 +-
 python/tvm/relax/op/index.py                       |   8 +-
 python/tvm/relax/op/memory/memory.py               |   8 +-
 python/tvm/relax/op/sampling.py                    |   2 -
 python/tvm/relax/op/unary.py                       |   5 +-
 python/tvm/relax/op/vm/vm.py                       |   8 +-
 python/tvm/relax/type_converter.py                 | 179 ---------------------
 python/tvm/relax/utils.py                          |   3 -
 python/tvm/s_tir/schedule/_type_checker.py         |   9 +-
 .../tvm/script/ir_builder/relax/distributed/ir.py  |   7 +-
 python/tvm/script/ir_builder/relax/ir.py           |   6 +-
 tests/python/relax/test_expr_args_converter.py     | 147 -----------------
 15 files changed, 49 insertions(+), 372 deletions(-)

diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 313048e4d4..65c44db3ac 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -21,8 +21,6 @@
 from tvm.runtime import vm
 from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind
 
-from .type_converter import args_converter
-
 # Expr
 from .expr import (
     Expr,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 10a4eb76b3..28a7aa897a 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -27,7 +27,7 @@ from tvm.runtime.object import Object
 from ...ir import PrimExpr
 from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var
 from ..struct_info import StructInfo, TensorStructInfo
-from ..utils import args_converter
+from ..utils import convert_to_expr
 from . import _ffi_api
 
 py_print = print  # pylint: disable=invalid-name
@@ -76,7 +76,9 @@ def _wrap_inline_arg_tuple(args) -> Expr:
     in-line relax Tuple.
 
     """
-    if (
+    if isinstance(args, tuple | list):
+        return tvm.relax.Tuple([convert_to_expr(a) for a in args])
+    elif (
         isinstance(args, Expr)
         and not isinstance(args, tvm.relax.Tuple)
         and (
@@ -89,7 +91,6 @@ def _wrap_inline_arg_tuple(args) -> Expr:
         return args
 
 
-@args_converter.auto
 def call_tir(
     gvar: GlobalVar,
     args: Expr,
@@ -131,7 +132,6 @@ def call_tir(
     return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars)  # type: ignore
 
 
-@args_converter.auto
 def call_tir_with_grad(
     gvar: GlobalVar,
     args: Expr,
@@ -190,7 +190,6 @@ def call_tir_with_grad(
     )
 
 
-@args_converter.auto
 def call_tir_inplace(
     gvar: GlobalVar,
     args: Expr,
@@ -261,7 +260,6 @@ def call_tir_inplace(
     )
 
 
-@args_converter.auto
 def call_dps_packed(
     func: str | Expr,
     args: Expr,
@@ -303,7 +301,6 @@ def call_dps_packed(
     return _ffi_api.call_dps_packed(func, args, out_sinfo)  # type: ignore
 
 
-@args_converter.auto
 def call_py_func(
     func_name: str,
     args: Expr,
@@ -339,7 +336,6 @@ def call_py_func(
     return _ffi_api.call_py_func(func_name, args, out_sinfo)  # type: ignore
 
 
-@args_converter.auto
 def call_builtin_with_ctx(
     func: str | Expr,
     args: Expr,
@@ -367,6 +363,8 @@ def call_builtin_with_ctx(
     if isinstance(func, str):
         func = ExternFunc(func)
 
+    args = _wrap_inline_arg_tuple(args)
+
     if sinfo_args is not None and not isinstance(sinfo_args, list | tuple):
         sinfo_args = [sinfo_args]
 
@@ -377,7 +375,6 @@ def call_builtin_with_ctx(
     )
 
 
-@args_converter.auto
 def make_closure(
     func: Expr,
     args: Expr,
@@ -400,10 +397,11 @@ def make_closure(
         The VMClosure.
     """
 
+    args = _wrap_inline_arg_tuple(args)
+
     return _ffi_api.make_closure(func, args)  # type: ignore
 
 
-@args_converter.auto
 def invoke_closure(
     closure: Expr,
     args: Expr,
@@ -428,6 +426,7 @@ def invoke_closure(
     ret: Call
         A call to `invoke_closure`.
     """
+    args = _wrap_inline_arg_tuple(args)
 
     if not isinstance(sinfo_args, list | tuple):
         sinfo_args = [sinfo_args]
@@ -677,7 +676,6 @@ def shape_to_tensor(expr: Expr) -> Expr:
     return _ffi_api.shape_to_tensor(expr)  # type: ignore # pylint: disable=no-member
 
 
-@args_converter.auto
 def call_inplace_packed(
     func: str | ExternFunc | GlobalVar,
     *args: Expr,
@@ -731,6 +729,7 @@ def call_inplace_packed(
         func = func.global_symbol
 
     op = ExternFunc(func)
+    args = tuple(convert_to_expr(a) for a in args)
     if sinfo_args is None:
         raise ValueError("R.call_pure_packed is required to have type_args")
     if isinstance(sinfo_args, tuple):  # type: ignore
@@ -743,7 +742,6 @@ def call_inplace_packed(
     return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args)  # type: ignore # pylint: disable=no-member
 
 
-@args_converter.auto
 def call_pure_packed(
     func: str | ExternFunc | GlobalVar,
     *args: Expr,
@@ -782,6 +780,7 @@ def call_pure_packed(
         func = func.global_symbol
 
     op = ExternFunc(func)
+    args = tuple(convert_to_expr(a) for a in args)
 
     if sinfo_args is None:
         raise ValueError("R.call_pure_packed is required to have type_args")
@@ -807,7 +806,6 @@ def call_pure_packed(
     return _ffi_api.call_pure_packed(op, args, None, sinfo_args)  # type: ignore # pylint: disable=no-member
 
 
-@args_converter.auto
 def invoke_pure_closure(
     closure: Expr,
     args: Expr,
@@ -838,6 +836,7 @@ def invoke_pure_closure(
     ret: Call
         A call to `invoke_pure_closure`.
     """
+    args = _wrap_inline_arg_tuple(args)
 
     if not isinstance(sinfo_args, list | tuple):
         sinfo_args = [sinfo_args]
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
index 47fab9c7a5..328d63f3cc 100644
--- a/python/tvm/relax/op/builtin/builtin.py
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -16,11 +16,10 @@
 """The builtin Relax operators."""
 
 from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
-from ...utils import args_converter
+from ...utils import convert_to_expr
 from . import _ffi_api
 
 
-@args_converter.auto
 def alloc_tensor(
     shape: Expr,
     dtype: str | Expr,
@@ -49,6 +48,8 @@ def alloc_tensor(
     result : Call
         A relax Call, which gets the allocated tensor.
     """
+    if not isinstance(shape, Expr):
+        shape = convert_to_expr(shape)
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
     if isinstance(runtime_device_index, int):
diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py
index b17708223d..b09f8686ac 100644
--- a/python/tvm/relax/op/distributed/distributed.py
+++ b/python/tvm/relax/op/distributed/distributed.py
@@ -20,10 +20,10 @@
 from tvm.ir import PrimExpr
 from tvm.relax.distributed import DTensorStructInfo
 from tvm.relax.distributed.struct_info import DeviceMesh, Placement
-from tvm.relax.utils import args_converter
 
 from ...expr import Call, Expr, GlobalVar, ShapeExpr
 from ...expr import Tuple as RxTuple
+from ...utils import convert_to_expr
 from . import _ffi_api
 
 
@@ -66,7 +66,6 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
     return _ffi_api.redistribute(input, device_mesh, placement)  # type: ignore
 
 
-@args_converter.auto
 def call_tir_local_view(
     gvar: GlobalVar,
     args: Expr,
@@ -99,7 +98,9 @@ def call_tir_local_view(
     ret: Call
         A call node for the call_tir_local_view operator.
     """
-    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+    if isinstance(args, tuple | list):
+        args = RxTuple([convert_to_expr(a) for a in args])
+    elif isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
         args = RxTuple((args,))
 
     if not isinstance(out_sinfo, list):
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index 1edb807db8..71cae89c3c 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -18,8 +18,8 @@
 
 from tvm.ir.expr import PrimExpr
 
-from .. import args_converter
 from ..expr import Expr
+from ..utils import convert_to_expr
 from . import _ffi_api
 
 PrimExprLike = int | PrimExpr
@@ -58,7 +58,6 @@ def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") ->
     return _ffi_api.take(x, indices, axis, mode)  # type: ignore
 
 
-@args_converter.auto
 def strided_slice(
     x: Expr,
     axes: Expr,
@@ -101,6 +100,11 @@ def strided_slice(
     strided_slice require the input `begin`, `end` and `strides` to have the
     same length as `axes`.
     """
+    axes = convert_to_expr(axes)
+    begin = convert_to_expr(begin)
+    end = convert_to_expr(end)
+    if strides is not None:
+        strides = convert_to_expr(strides)
     return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound)  # type: ignore
 
 
diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py
index ac8343c310..ba54eea9ef 100644
--- a/python/tvm/relax/op/memory/memory.py
+++ b/python/tvm/relax/op/memory/memory.py
@@ -16,11 +16,10 @@
 """Relax memory primitives."""
 
 from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
-from ...utils import args_converter
+from ...utils import convert_to_expr
 from . import _ffi_api
 
 
-@args_converter.auto
 def alloc_storage(
     size: Expr,
     virtual_device_index: int | Expr,
@@ -50,6 +49,7 @@ def alloc_storage(
     result : Call
         A relax Call, which gets the allocated storage.
     """
+    size = convert_to_expr(size)
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
     if isinstance(storage_scope, str):
@@ -59,7 +59,6 @@ def alloc_storage(
     return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype)  # type: ignore
 
 
-@args_converter.auto
 def alloc_tensor(
     storage: Expr,
     offset: int | Expr,
@@ -94,12 +93,12 @@ def alloc_tensor(
     """
     if isinstance(offset, int):
         offset = PrimValue(offset)
+    shape = convert_to_expr(shape)
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
     return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind)  # type: ignore
 
 
-@args_converter.auto
 def kill_storage(storage: Expr) -> Call:
     """Construct a Call to kill a storage.
 
@@ -116,7 +115,6 @@ def kill_storage(storage: Expr) -> Call:
     return _ffi_api.kill_storage(storage)  # type: ignore
 
 
-@args_converter.auto
 def kill_tensor(tensor: Expr) -> Call:
     """Construct a Call to kill a tensor.
 
diff --git a/python/tvm/relax/op/sampling.py b/python/tvm/relax/op/sampling.py
index bcd43a3922..cd4cad9253 100644
--- a/python/tvm/relax/op/sampling.py
+++ b/python/tvm/relax/op/sampling.py
@@ -16,12 +16,10 @@
 # under the License.
 """Sampling operators."""
 
-from .. import args_converter
 from ..expr import Expr
 from . import _ffi_api
 
 
-@args_converter.auto
 def multinomial_from_uniform(
     prob: Expr,
     uniform_sample: Expr,
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index c77d8311d5..e7b800b5c9 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -18,7 +18,7 @@
 """Relax unary arithmetic operators."""
 
 from ..expr import Expr
-from ..utils import args_converter
+from ..utils import convert_to_expr
 from . import _ffi_api
 
 ###################### Arithmetic operators ######################
@@ -526,7 +526,6 @@ def trunc(x: Expr) -> Expr:
     return _ffi_api.trunc(x)  # type: ignore
 
 
-@args_converter.auto
 def clip(x: Expr, min: Expr, max: Expr) -> Expr:
     """Clips tensor values to a specified min and max.
 
@@ -546,6 +545,8 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr:
     result : relax.Expr
         The computed result.
     """
+    min = convert_to_expr(min)
+    max = convert_to_expr(max)
     return _ffi_api.clip(x, min, max)  # type: ignore
 
 
diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py
index de1015d08d..868ac92295 100644
--- a/python/tvm/relax/op/vm/vm.py
+++ b/python/tvm/relax/op/vm/vm.py
@@ -16,11 +16,10 @@
 """Relax vm primitives."""
 
 from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm, Tuple
-from ...utils import args_converter
+from ...utils import convert_to_expr
 from . import _ffi_api
 
 
-@args_converter.auto
 def alloc_storage(
     shape: Expr,
     runtime_device_index: int | Expr,
@@ -50,6 +49,7 @@ def alloc_storage(
     result : Call
         A relax Call, which gets the allocated storage.
     """
+    shape = convert_to_expr(shape)
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
     if isinstance(storage_scope, str):
@@ -59,7 +59,6 @@ def alloc_storage(
     return _ffi_api.alloc_storage(shape, runtime_device_index, dtype, storage_scope)  # type: ignore
 
 
-@args_converter.auto
 def alloc_tensor(
     storage: Expr,
     offset: int | Expr,
@@ -94,6 +93,7 @@ def alloc_tensor(
     """
     if isinstance(offset, int):
         offset = PrimValue(offset)
+    shape = convert_to_expr(shape)
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
     return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind)  # type: ignore
@@ -116,7 +116,6 @@ def kill_object(obj: Expr) -> Call:
     return _ffi_api.kill_object(obj)  # type: ignore
 
 
-@args_converter.auto
 def call_tir_dyn(func: Expr, args: Tuple) -> Call:
     """Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc)
     consisting of the input tensors and the shape of the result.
@@ -134,6 +133,7 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call:
     result : Call
         A relax Call to call_tir_dyn.
     """
+    func = convert_to_expr(func)
     if isinstance(args, list | tuple):
         args = Tuple(args)
 
diff --git a/python/tvm/relax/type_converter.py b/python/tvm/relax/type_converter.py
deleted file mode 100644
index 50014cba5a..0000000000
--- a/python/tvm/relax/type_converter.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=invalid-name,too-many-locals
-
-"""Argument converter utility for Relax
-
-This utility is used to decorate constructors of `tvm.relax.Expr`, and
-must be able to be imported before `tvm.relax.Expr` or its subtypes
-have been defined.  Neither the class definitions nor any type
-signature in this file may reference relax types.  All references must
-be exclusively in function bodies to avoid having a circular reference
-during module imports.
-"""
-
-import functools
-import inspect
-from collections.abc import Callable
-from typing import Any, TypeVar
-
-import tvm
-
-FType = TypeVar("FType", bound=Callable[..., "tvm.relax.Expr"])
-
-
-class _ArgsConverter:
-    """A helper class to convert the arguments to Expr."""
-
-    @staticmethod
-    def convert(args_to_expr: list[str], args_to_list_expr: list[str]):
-        """Convert the arguments to Expr.
-
-        Parameters
-        ----------
-        args_to_expr : List[str]
-            The argument names to be converted to Expr.
-
-        args_to_list_expr : List[str]
-            The argument names to be converted to List[Expr].
-
-        Returns
-        -------
-        output : Callable[[FType], FType]
-            The decorator.
-        """
-
-        if any([x in args_to_list_expr for x in args_to_expr]):
-            raise ValueError("`args_to_expr` and `args_to_list_expr` should be disjoint.")
-
-        def _convert(name: str, value: Any) -> Any:
-            if value is None:
-                return value
-            if name in args_to_expr:
-                try:
-                    return tvm.relax.utils.convert_to_expr(value)
-                except Exception as err:
-                    raise TypeError(
-                        f"Argument `{name}` is expected to be converted to `Expr`, "
-                        f"but failed with input value: {value}"
-                    ) from err
-            elif name in args_to_list_expr:
-                try:
-                    return [tvm.relax.utils.convert_to_expr(x) for x in value]
-                except Exception as err:
-                    raise TypeError(
-                        f"Argument `{name}` is expected to be converted to `List[Expr]`, "
-                        f"but failed with input value: {value}"
-                    ) from err
-            else:
-                return value
-
-        def inner(func: FType) -> FType:
-            sig = inspect.signature(func)
-            param_names = list(sig.parameters.keys())
-            for name in args_to_expr + args_to_list_expr:
-                if name not in param_names:
-                    raise ValueError(f"Argument `{name}` is not found in function signature.")
-
-            @functools.wraps(func)
-            def wrapper(*args, **kwargs):
-                bound = sig.bind(*args, **kwargs)
-                bound.apply_defaults()
-                for param in sig.parameters.values():
-                    if param.kind == param.VAR_POSITIONAL:
-                        # *args case
-                        values = [_convert(param.name, x) for x in bound.arguments[param.name]]
-                        bound.arguments[param.name] = tuple(values)
-                    elif param.kind == param.VAR_KEYWORD:
-                        # **kwargs case
-                        key_value = {
-                            key: _convert(param.name, value)
-                            for key, value in bound.arguments[param.name].items()
-                        }
-                        bound.arguments[param.name] = key_value
-                    else:
-                        bound.arguments[param.name] = _convert(
-                            param.name, bound.arguments[param.name]
-                        )
-                return func(*bound.args, **bound.kwargs)
-
-            return wrapper  # type: ignore
-
-        return inner
-
-    @staticmethod
-    def to_expr(*arg_names: str) -> Callable:
-        """Convert the arguments to Expr.
-
-        Parameters
-        ----------
-        *arg_names: str
-            The list of argument names that need to be converted to Expr.
-
-        Returns
-        -------
-        output: Callable
-            The decorator.
-        """
-
-        return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[])
-
-    @staticmethod
-    def to_list_expr(*arg_names: str) -> Callable:
-        """Convert the arguments to List of Expr.
-
-        Parameters
-        ----------
-        *arg_names: str
-            The list of argument names that need to be converted to List of Expr.
-
-        Returns
-        -------
-        output: Callable
-            The decorator.
-        """
-
-        return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names))
-
-    @staticmethod
-    def auto(func: FType) -> FType:
-        """Decorator for automatically convert the arguments to Expr according to type annotation.
-        Only two patterns are supported:
-
-        1. The argument is Expr or Expr | None.
-
-        2. The argument is List[Expr] or Optional[List[Expr]].
-
-        """
-        sig = inspect.signature(func)
-        args_to_expr = []
-        args_to_list_expr = []
-
-        from . import Expr  # pylint: disable=import-outside-toplevel
-
-        for param in sig.parameters.values():
-            anno = param.annotation
-            if anno in (Expr, Expr | None):
-                args_to_expr.append(param.name)
-            if anno in (list[Expr], list[Expr] | None):
-                args_to_list_expr.append(param.name)
-
-        return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func)
-
-
-args_converter = _ArgsConverter()  # pylint: disable=invalid-name
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index f76459294a..75f62c525d 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -39,9 +39,6 @@ from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor
 from .expr import Tuple as rx_Tuple
 from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
 
-# Re-export `args_converter` here for backwards compatibility
-from .type_converter import args_converter  # pylint: disable=unused-import
-
 
 def metadata_partitioner(rx_txt: str) -> list[str]:
     """Extract Relax program and metadata section.
diff --git a/python/tvm/s_tir/schedule/_type_checker.py b/python/tvm/s_tir/schedule/_type_checker.py
index 57e612688b..5ad4f29c55 100644
--- a/python/tvm/s_tir/schedule/_type_checker.py
+++ b/python/tvm/s_tir/schedule/_type_checker.py
@@ -345,17 +345,22 @@ FType = TypeVar("FType", bound=Callable[..., Any])
 def type_checked(func: FType) -> FType:
     """Type check the input arguments of a function."""
     sig = inspect.signature(func)
+    try:
+        hints = typing.get_type_hints(func)
+    except Exception:
+        hints = {}
 
     @functools.wraps(func)
     def wrap(*args, **kwargs):
         bound_args = sig.bind(*args, **kwargs)
         bound_args.apply_defaults()
         for param in sig.parameters.values():
-            if param.annotation != inspect.Signature.empty:
+            type_hint = hints.get(param.name, inspect.Parameter.empty)
+            if type_hint != inspect.Parameter.empty:
                 error_msg = _type_check(
                     bound_args.arguments[param.name],
                     param.name,
-                    param.annotation,
+                    type_hint,
                 )
                 if error_msg is not None:
                     error_msg = f'In "{func.__qualname__}", {error_msg}'
diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py
index 49bfb6139a..485e91e3e5 100644
--- a/python/tvm/script/ir_builder/relax/distributed/ir.py
+++ b/python/tvm/script/ir_builder/relax/distributed/ir.py
@@ -40,7 +40,7 @@ from tvm.relax.op.distributed import (
 from tvm.relax.op.distributed import (
     redistribute as _redistribute,
 )
-from tvm.relax.utils import args_converter
+from tvm.relax.utils import convert_to_expr
 from tvm.runtime import _tensor
 
 from ... import IRBuilder
@@ -49,7 +49,6 @@ from ..ir import py_str
 from . import _ffi_api
 
 
-@args_converter.auto
 def call_tir(
     func: str | Expr,
     args: Expr,
@@ -82,7 +81,9 @@ def call_tir(
     if isinstance(func, str):
         func = ExternFunc(func)
 
-    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+    if isinstance(args, tuple | list):
+        args = RxTuple([convert_to_expr(a) for a in args])
+    elif isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
         args = RxTuple((args,))
 
     if not isinstance(out_sinfo, list):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index a7428c32a3..a40517992e 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -198,7 +198,7 @@ from tvm.relax.op import (
 )
 from tvm.relax.op.builtin import stop_lift_params
 from tvm.relax.struct_info import StructInfo
-from tvm.relax.utils import args_converter, gen_call_tir_inputs
+from tvm.relax.utils import convert_to_expr, gen_call_tir_inputs
 from tvm.runtime import Object as tvm_Object
 from tvm.runtime import ObjectConvertible
 from tvm.runtime._tensor import (
@@ -403,7 +403,6 @@ def output(*vars: tuple[Var]) -> None:
 ################################## Ops #################################
 
 
-@args_converter.auto
 def call_packed(
     func: py_str,
     *args: Expr,
@@ -428,6 +427,7 @@ def call_packed(
         The created Relax Call
     """
     op = ExternFunc(func)
+    args = py_tuple(convert_to_expr(a) for a in args)
     if sinfo_args is None:
         sinfo_args = []
     if isinstance(sinfo_args, py_tuple):  # type: ignore
@@ -460,7 +460,6 @@ def call_packed(
     return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)
 
 
-@args_converter.auto
 def call_py_func(
     py_func_name: py_str,
     *args: Expr,
@@ -485,6 +484,7 @@ def call_py_func(
     call: Call
         The created Relax Call for call_py_func operator.
     """
+    args = py_tuple(convert_to_expr(a) for a in args)
     if isinstance(out_sinfo, py_tuple):  # type: ignore
         out_sinfo = list(out_sinfo)
     elif not isinstance(out_sinfo, list):
diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py
deleted file mode 100644
index d156245452..0000000000
--- a/tests/python/relax/test_expr_args_converter.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: E731
-
-from collections.abc import Callable
-from typing import Any
-
-import pytest
-
-import tvm
-import tvm.testing
-from tvm import relax
-from tvm.relax import Expr
-from tvm.relax.utils import args_converter
-
-
-def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None:
-    # Test converting to `Expr`
-    assert f_checker(arg)
-    # Test converting `*args`
-    assert isinstance(args, tuple)
-    assert all([f_checker(arg) for arg in args])
-    # Test converting `**kwargs`
-    assert isinstance(kwargs, dict)
-    assert all([f_checker(arg) for arg in kwargs.values()])
-
-
-def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None:
-    f_checker = lambda x: isinstance(x, Expr)
-    _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_optional_expr(arg: Expr | None, *args: Expr | None, **kwargs: Expr | None) -> None:
-    f_checker = lambda x: x is None or isinstance(x, Expr)
-    _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_list_expr(arg: list[Expr], *args: list[Expr], **kwargs: list[Expr]) -> None:
-    f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr) for arg in x])
-    _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_optional_list_expr(
-    arg: list[Expr] | None, *args: list[Expr] | None, **kwargs: list[Expr] | None
-) -> None:
-    f_checker = lambda x: (
-        x is None or (isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]))
-    )
-    _test_base(f_checker, arg, *args, **kwargs)
-
-
-prim_value = 1
-str_value = "value_to_be_convert"
-shape_value = (1, 1)
-tuple_value = (relax.const(1), (1, 1))
-placeholder = relax.const(0)
-
-test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder]
-
-
-def test_args_to_expr():
-    for _f in [_test_expr, _test_optional_expr]:
-        f = args_converter.to_expr("arg", "args", "kwargs")(_f)
-        for x in test_cases:
-            f(
-                x,
-                x,  # the first argument in *args
-                x,  # the second argument in *args
-                test_kwargs=x,
-            )
-
-            if _f == _test_optional_expr:
-                f(None, None, x, test_kwargs=None)
-
-
-def test_args_to_list_expr():
-    for _f in [_test_list_expr, _test_optional_list_expr]:
-        f = args_converter.to_list_expr("arg", "args", "kwargs")(_f)
-        for x in test_cases:
-            f(
-                [x],
-                [x],  # the first argument in *args
-                [x, x],  # the second argument in *args
-                test_kwargs=[x, (x,)],
-            )
-
-            if _f == _test_optional_list_expr:
-                f(None, None, [x], test_kwargs=None)
-
-
-def test_error():
-    f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr)
-    with pytest.raises(TypeError):
-        f(prim_value)  # fail to convert prim_value to `List[Expr]`
-
-
-def test_auto_convert():
-    for _f in [_test_expr, _test_optional_expr]:
-        f = args_converter.auto(_f)
-        for x in test_cases:
-            f(x, (x,), test_kwargs=x)
-
-            if _f == _test_optional_expr:
-                f(None, x, test_kwargs=None)
-
-    for _f in [_test_list_expr, _test_optional_list_expr]:
-        f = args_converter.auto(_f)
-        for x in test_cases:
-            f([x], [x, x], test_kwargs=[x, (x,)])
-
-            if _f == _test_optional_list_expr:
-                f(None, None, [x], test_kwargs=None)
-
-
-def test_auto_convert_skip():
-    def _test_expr_skip(arg: int, *args: str | Expr, **kwargs: list[Expr | None]) -> None:
-        f_checker = lambda x: not isinstance(x, Expr)
-        _test_base(f_checker, arg, *args, **kwargs)
-
-    f = args_converter.auto(_test_expr_skip)
-    f(1, "str", test_kwargs=[None])
-
-
-def test_empty_tuple():
-    def _test(arg: Expr):
-        assert isinstance(arg, relax.Tuple)
-
-    f = args_converter.auto(_test)
-    f(())
-
-
-if __name__ == "__main__":
-    tvm.testing.main()

Reply via email to