This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2fd4e11194 [PYTHON] Fix PEP 563 compat and remove args_converter (#18847)
2fd4e11194 is described below
commit 2fd4e111949ca60f6a76e09d185847b07a8f8db1
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 28 10:41:10 2026 -0500
[PYTHON] Fix PEP 563 compat and remove args_converter (#18847)
---
python/tvm/relax/__init__.py | 2 -
python/tvm/relax/op/base.py | 25 ++-
python/tvm/relax/op/builtin/builtin.py | 5 +-
python/tvm/relax/op/distributed/distributed.py | 7 +-
python/tvm/relax/op/index.py | 8 +-
python/tvm/relax/op/memory/memory.py | 8 +-
python/tvm/relax/op/sampling.py | 2 -
python/tvm/relax/op/unary.py | 5 +-
python/tvm/relax/op/vm/vm.py | 8 +-
python/tvm/relax/type_converter.py | 179 ---------------------
python/tvm/relax/utils.py | 3 -
python/tvm/s_tir/schedule/_type_checker.py | 9 +-
.../tvm/script/ir_builder/relax/distributed/ir.py | 7 +-
python/tvm/script/ir_builder/relax/ir.py | 6 +-
tests/python/relax/test_expr_args_converter.py | 147 -----------------
15 files changed, 49 insertions(+), 372 deletions(-)
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 313048e4d4..65c44db3ac 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -21,8 +21,6 @@
from tvm.runtime import vm
from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind
-from .type_converter import args_converter
-
# Expr
from .expr import (
Expr,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 10a4eb76b3..28a7aa897a 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -27,7 +27,7 @@ from tvm.runtime.object import Object
from ...ir import PrimExpr
from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var
from ..struct_info import StructInfo, TensorStructInfo
-from ..utils import args_converter
+from ..utils import convert_to_expr
from . import _ffi_api
py_print = print # pylint: disable=invalid-name
@@ -76,7 +76,9 @@ def _wrap_inline_arg_tuple(args) -> Expr:
in-line relax Tuple.
"""
- if (
+ if isinstance(args, tuple | list):
+ return tvm.relax.Tuple([convert_to_expr(a) for a in args])
+ elif (
isinstance(args, Expr)
and not isinstance(args, tvm.relax.Tuple)
and (
@@ -89,7 +91,6 @@ def _wrap_inline_arg_tuple(args) -> Expr:
return args
-@args_converter.auto
def call_tir(
gvar: GlobalVar,
args: Expr,
@@ -131,7 +132,6 @@ def call_tir(
return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore
-@args_converter.auto
def call_tir_with_grad(
gvar: GlobalVar,
args: Expr,
@@ -190,7 +190,6 @@ def call_tir_with_grad(
)
-@args_converter.auto
def call_tir_inplace(
gvar: GlobalVar,
args: Expr,
@@ -261,7 +260,6 @@ def call_tir_inplace(
)
-@args_converter.auto
def call_dps_packed(
func: str | Expr,
args: Expr,
@@ -303,7 +301,6 @@ def call_dps_packed(
return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore
-@args_converter.auto
def call_py_func(
func_name: str,
args: Expr,
@@ -339,7 +336,6 @@ def call_py_func(
return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore
-@args_converter.auto
def call_builtin_with_ctx(
func: str | Expr,
args: Expr,
@@ -367,6 +363,8 @@ def call_builtin_with_ctx(
if isinstance(func, str):
func = ExternFunc(func)
+ args = _wrap_inline_arg_tuple(args)
+
if sinfo_args is not None and not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]
@@ -377,7 +375,6 @@ def call_builtin_with_ctx(
)
-@args_converter.auto
def make_closure(
func: Expr,
args: Expr,
@@ -400,10 +397,11 @@ def make_closure(
The VMClosure.
"""
+ args = _wrap_inline_arg_tuple(args)
+
return _ffi_api.make_closure(func, args) # type: ignore
-@args_converter.auto
def invoke_closure(
closure: Expr,
args: Expr,
@@ -428,6 +426,7 @@ def invoke_closure(
ret: Call
A call to `invoke_closure`.
"""
+ args = _wrap_inline_arg_tuple(args)
if not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]
@@ -677,7 +676,6 @@ def shape_to_tensor(expr: Expr) -> Expr:
return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint:
disable=no-member
-@args_converter.auto
def call_inplace_packed(
func: str | ExternFunc | GlobalVar,
*args: Expr,
@@ -731,6 +729,7 @@ def call_inplace_packed(
func = func.global_symbol
op = ExternFunc(func)
+ args = tuple(convert_to_expr(a) for a in args)
if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")
if isinstance(sinfo_args, tuple): # type: ignore
@@ -743,7 +742,6 @@ def call_inplace_packed(
return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args)
# type: ignore # pylint: disable=no-member
-@args_converter.auto
def call_pure_packed(
func: str | ExternFunc | GlobalVar,
*args: Expr,
@@ -782,6 +780,7 @@ def call_pure_packed(
func = func.global_symbol
op = ExternFunc(func)
+ args = tuple(convert_to_expr(a) for a in args)
if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")
@@ -807,7 +806,6 @@ def call_pure_packed(
return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type:
ignore # pylint: disable=no-member
-@args_converter.auto
def invoke_pure_closure(
closure: Expr,
args: Expr,
@@ -838,6 +836,7 @@ def invoke_pure_closure(
ret: Call
A call to `invoke_pure_closure`.
"""
+ args = _wrap_inline_arg_tuple(args)
if not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
index 47fab9c7a5..328d63f3cc 100644
--- a/python/tvm/relax/op/builtin/builtin.py
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -16,11 +16,10 @@
"""The builtin Relax operators."""
from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
-from ...utils import args_converter
+from ...utils import convert_to_expr
from . import _ffi_api
-@args_converter.auto
def alloc_tensor(
shape: Expr,
dtype: str | Expr,
@@ -49,6 +48,8 @@ def alloc_tensor(
result : Call
A relax Call, which gets the allocated tensor.
"""
+ if not isinstance(shape, Expr):
+ shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_index, int):
diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py
index b17708223d..b09f8686ac 100644
--- a/python/tvm/relax/op/distributed/distributed.py
+++ b/python/tvm/relax/op/distributed/distributed.py
@@ -20,10 +20,10 @@
from tvm.ir import PrimExpr
from tvm.relax.distributed import DTensorStructInfo
from tvm.relax.distributed.struct_info import DeviceMesh, Placement
-from tvm.relax.utils import args_converter
from ...expr import Call, Expr, GlobalVar, ShapeExpr
from ...expr import Tuple as RxTuple
+from ...utils import convert_to_expr
from . import _ffi_api
@@ -66,7 +66,6 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore
-@args_converter.auto
def call_tir_local_view(
gvar: GlobalVar,
args: Expr,
@@ -99,7 +98,9 @@ def call_tir_local_view(
ret: Call
A call node for the call_tir_local_view operator.
"""
- if isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
+ if isinstance(args, tuple | list):
+ args = RxTuple([convert_to_expr(a) for a in args])
+ elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
args = RxTuple((args,))
if not isinstance(out_sinfo, list):
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index 1edb807db8..71cae89c3c 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -18,8 +18,8 @@
from tvm.ir.expr import PrimExpr
-from .. import args_converter
from ..expr import Expr
+from ..utils import convert_to_expr
from . import _ffi_api
PrimExprLike = int | PrimExpr
@@ -58,7 +58,6 @@ def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") ->
return _ffi_api.take(x, indices, axis, mode) # type: ignore
-@args_converter.auto
def strided_slice(
x: Expr,
axes: Expr,
@@ -101,6 +100,11 @@ def strided_slice(
strided_slice require the input `begin`, `end` and `strides` to have the
same length as `axes`.
"""
+ axes = convert_to_expr(axes)
+ begin = convert_to_expr(begin)
+ end = convert_to_expr(end)
+ if strides is not None:
+ strides = convert_to_expr(strides)
return _ffi_api.strided_slice(x, axes, begin, end, strides,
assume_inbound) # type: ignore
diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py
index ac8343c310..ba54eea9ef 100644
--- a/python/tvm/relax/op/memory/memory.py
+++ b/python/tvm/relax/op/memory/memory.py
@@ -16,11 +16,10 @@
"""Relax memory primitives."""
from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
-from ...utils import args_converter
+from ...utils import convert_to_expr
from . import _ffi_api
-@args_converter.auto
def alloc_storage(
size: Expr,
virtual_device_index: int | Expr,
@@ -50,6 +49,7 @@ def alloc_storage(
result : Call
A relax Call, which gets the allocated storage.
"""
+ size = convert_to_expr(size)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(storage_scope, str):
@@ -59,7 +59,6 @@ def alloc_storage(
return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope,
dtype) # type: ignore
-@args_converter.auto
def alloc_tensor(
storage: Expr,
offset: int | Expr,
@@ -94,12 +93,12 @@ def alloc_tensor(
"""
if isinstance(offset, int):
offset = PrimValue(offset)
+ shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype,
runtime_device_ind) # type: ignore
-@args_converter.auto
def kill_storage(storage: Expr) -> Call:
"""Construct a Call to kill a storage.
@@ -116,7 +115,6 @@ def kill_storage(storage: Expr) -> Call:
return _ffi_api.kill_storage(storage) # type: ignore
-@args_converter.auto
def kill_tensor(tensor: Expr) -> Call:
"""Construct a Call to kill a tensor.
diff --git a/python/tvm/relax/op/sampling.py b/python/tvm/relax/op/sampling.py
index bcd43a3922..cd4cad9253 100644
--- a/python/tvm/relax/op/sampling.py
+++ b/python/tvm/relax/op/sampling.py
@@ -16,12 +16,10 @@
# under the License.
"""Sampling operators."""
-from .. import args_converter
from ..expr import Expr
from . import _ffi_api
-@args_converter.auto
def multinomial_from_uniform(
prob: Expr,
uniform_sample: Expr,
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index c77d8311d5..e7b800b5c9 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -18,7 +18,7 @@
"""Relax unary arithmetic operators."""
from ..expr import Expr
-from ..utils import args_converter
+from ..utils import convert_to_expr
from . import _ffi_api
###################### Arithmetic operators ######################
@@ -526,7 +526,6 @@ def trunc(x: Expr) -> Expr:
return _ffi_api.trunc(x) # type: ignore
-@args_converter.auto
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
"""Clips tensor values to a specified min and max.
@@ -546,6 +545,8 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr:
result : relax.Expr
The computed result.
"""
+ min = convert_to_expr(min)
+ max = convert_to_expr(max)
return _ffi_api.clip(x, min, max) # type: ignore
diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py
index de1015d08d..868ac92295 100644
--- a/python/tvm/relax/op/vm/vm.py
+++ b/python/tvm/relax/op/vm/vm.py
@@ -16,11 +16,10 @@
"""Relax vm primitives."""
from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm, Tuple
-from ...utils import args_converter
+from ...utils import convert_to_expr
from . import _ffi_api
-@args_converter.auto
def alloc_storage(
shape: Expr,
runtime_device_index: int | Expr,
@@ -50,6 +49,7 @@ def alloc_storage(
result : Call
A relax Call, which gets the allocated storage.
"""
+ shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(storage_scope, str):
@@ -59,7 +59,6 @@ def alloc_storage(
return _ffi_api.alloc_storage(shape, runtime_device_index, dtype,
storage_scope) # type: ignore
-@args_converter.auto
def alloc_tensor(
storage: Expr,
offset: int | Expr,
@@ -94,6 +93,7 @@ def alloc_tensor(
"""
if isinstance(offset, int):
offset = PrimValue(offset)
+ shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype,
runtime_device_ind) # type: ignore
@@ -116,7 +116,6 @@ def kill_object(obj: Expr) -> Call:
return _ffi_api.kill_object(obj) # type: ignore
-@args_converter.auto
def call_tir_dyn(func: Expr, args: Tuple) -> Call:
"""Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc)
consisting of the input tensors and the shape of the result.
@@ -134,6 +133,7 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call:
result : Call
A relax Call to call_tir_dyn.
"""
+ func = convert_to_expr(func)
if isinstance(args, list | tuple):
args = Tuple(args)
diff --git a/python/tvm/relax/type_converter.py b/python/tvm/relax/type_converter.py
deleted file mode 100644
index 50014cba5a..0000000000
--- a/python/tvm/relax/type_converter.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=invalid-name,too-many-locals
-
-"""Argument converter utility for Relax
-
-This utility is used to decorate constructors of `tvm.relax.Expr`, and
-must be able to be imported before `tvm.relax.Expr` or its subtypes
-have been defined. Neither the class definitions nor any type
-signature in this file may reference relax types. All references must
-be exclusively in function bodies to avoid having a circular reference
-during module imports.
-"""
-
-import functools
-import inspect
-from collections.abc import Callable
-from typing import Any, TypeVar
-
-import tvm
-
-FType = TypeVar("FType", bound=Callable[..., "tvm.relax.Expr"])
-
-
-class _ArgsConverter:
- """A helper class to convert the arguments to Expr."""
-
- @staticmethod
- def convert(args_to_expr: list[str], args_to_list_expr: list[str]):
- """Convert the arguments to Expr.
-
- Parameters
- ----------
- args_to_expr : List[str]
- The argument names to be converted to Expr.
-
- args_to_list_expr : List[str]
- The argument names to be converted to List[Expr].
-
- Returns
- -------
- output : Callable[[FType], FType]
- The decorator.
- """
-
- if any([x in args_to_list_expr for x in args_to_expr]):
- raise ValueError("`args_to_expr` and `args_to_list_expr` should be
disjoint.")
-
- def _convert(name: str, value: Any) -> Any:
- if value is None:
- return value
- if name in args_to_expr:
- try:
- return tvm.relax.utils.convert_to_expr(value)
- except Exception as err:
- raise TypeError(
- f"Argument `{name}` is expected to be converted to
`Expr`, "
- f"but failed with input value: {value}"
- ) from err
- elif name in args_to_list_expr:
- try:
- return [tvm.relax.utils.convert_to_expr(x) for x in value]
- except Exception as err:
- raise TypeError(
- f"Argument `{name}` is expected to be converted to
`List[Expr]`, "
- f"but failed with input value: {value}"
- ) from err
- else:
- return value
-
- def inner(func: FType) -> FType:
- sig = inspect.signature(func)
- param_names = list(sig.parameters.keys())
- for name in args_to_expr + args_to_list_expr:
- if name not in param_names:
- raise ValueError(f"Argument `{name}` is not found in
function signature.")
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- bound = sig.bind(*args, **kwargs)
- bound.apply_defaults()
- for param in sig.parameters.values():
- if param.kind == param.VAR_POSITIONAL:
- # *args case
- values = [_convert(param.name, x) for x in
bound.arguments[param.name]]
- bound.arguments[param.name] = tuple(values)
- elif param.kind == param.VAR_KEYWORD:
- # **kwargs case
- key_value = {
- key: _convert(param.name, value)
- for key, value in
bound.arguments[param.name].items()
- }
- bound.arguments[param.name] = key_value
- else:
- bound.arguments[param.name] = _convert(
- param.name, bound.arguments[param.name]
- )
- return func(*bound.args, **bound.kwargs)
-
- return wrapper # type: ignore
-
- return inner
-
- @staticmethod
- def to_expr(*arg_names: str) -> Callable:
- """Convert the arguments to Expr.
-
- Parameters
- ----------
- *arg_names: str
- The list of argument names that need to be converted to Expr.
-
- Returns
- -------
- output: Callable
- The decorator.
- """
-
- return _ArgsConverter.convert(args_to_expr=list(arg_names),
args_to_list_expr=[])
-
- @staticmethod
- def to_list_expr(*arg_names: str) -> Callable:
- """Convert the arguments to List of Expr.
-
- Parameters
- ----------
- *arg_names: str
- The list of argument names that need to be converted to List of
Expr.
-
- Returns
- -------
- output: Callable
- The decorator.
- """
-
- return _ArgsConverter.convert(args_to_expr=[],
args_to_list_expr=list(arg_names))
-
- @staticmethod
- def auto(func: FType) -> FType:
- """Decorator for automatically convert the arguments to Expr according
to type annotation.
- Only two patterns are supported:
-
- 1. The argument is Expr or Expr | None.
-
- 2. The argument is List[Expr] or Optional[List[Expr]].
-
- """
- sig = inspect.signature(func)
- args_to_expr = []
- args_to_list_expr = []
-
- from . import Expr # pylint: disable=import-outside-toplevel
-
- for param in sig.parameters.values():
- anno = param.annotation
- if anno in (Expr, Expr | None):
- args_to_expr.append(param.name)
- if anno in (list[Expr], list[Expr] | None):
- args_to_list_expr.append(param.name)
-
- return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func)
-
-
-args_converter = _ArgsConverter() # pylint: disable=invalid-name
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index f76459294a..75f62c525d 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -39,9 +39,6 @@ from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor
from .expr import Tuple as rx_Tuple
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
-# Re-export `args_converter` here for backwards compatibility
-from .type_converter import args_converter # pylint: disable=unused-import
-
def metadata_partitioner(rx_txt: str) -> list[str]:
"""Extract Relax program and metadata section.
diff --git a/python/tvm/s_tir/schedule/_type_checker.py b/python/tvm/s_tir/schedule/_type_checker.py
index 57e612688b..5ad4f29c55 100644
--- a/python/tvm/s_tir/schedule/_type_checker.py
+++ b/python/tvm/s_tir/schedule/_type_checker.py
@@ -345,17 +345,22 @@ FType = TypeVar("FType", bound=Callable[..., Any])
def type_checked(func: FType) -> FType:
"""Type check the input arguments of a function."""
sig = inspect.signature(func)
+ try:
+ hints = typing.get_type_hints(func)
+ except Exception:
+ hints = {}
@functools.wraps(func)
def wrap(*args, **kwargs):
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
for param in sig.parameters.values():
- if param.annotation != inspect.Signature.empty:
+ type_hint = hints.get(param.name, inspect.Parameter.empty)
+ if type_hint != inspect.Parameter.empty:
error_msg = _type_check(
bound_args.arguments[param.name],
param.name,
- param.annotation,
+ type_hint,
)
if error_msg is not None:
error_msg = f'In "{func.__qualname__}", {error_msg}'
diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py
index 49bfb6139a..485e91e3e5 100644
--- a/python/tvm/script/ir_builder/relax/distributed/ir.py
+++ b/python/tvm/script/ir_builder/relax/distributed/ir.py
@@ -40,7 +40,7 @@ from tvm.relax.op.distributed import (
from tvm.relax.op.distributed import (
redistribute as _redistribute,
)
-from tvm.relax.utils import args_converter
+from tvm.relax.utils import convert_to_expr
from tvm.runtime import _tensor
from ... import IRBuilder
@@ -49,7 +49,6 @@ from ..ir import py_str
from . import _ffi_api
-@args_converter.auto
def call_tir(
func: str | Expr,
args: Expr,
@@ -82,7 +81,9 @@ def call_tir(
if isinstance(func, str):
func = ExternFunc(func)
- if isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
+ if isinstance(args, tuple | list):
+ args = RxTuple([convert_to_expr(a) for a in args])
+ elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type:
ignore
args = RxTuple((args,))
if not isinstance(out_sinfo, list):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index a7428c32a3..a40517992e 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -198,7 +198,7 @@ from tvm.relax.op import (
)
from tvm.relax.op.builtin import stop_lift_params
from tvm.relax.struct_info import StructInfo
-from tvm.relax.utils import args_converter, gen_call_tir_inputs
+from tvm.relax.utils import convert_to_expr, gen_call_tir_inputs
from tvm.runtime import Object as tvm_Object
from tvm.runtime import ObjectConvertible
from tvm.runtime._tensor import (
@@ -403,7 +403,6 @@ def output(*vars: tuple[Var]) -> None:
################################## Ops #################################
-@args_converter.auto
def call_packed(
func: py_str,
*args: Expr,
@@ -428,6 +427,7 @@ def call_packed(
The created Relax Call
"""
op = ExternFunc(func)
+ args = py_tuple(convert_to_expr(a) for a in args)
if sinfo_args is None:
sinfo_args = []
if isinstance(sinfo_args, py_tuple): # type: ignore
@@ -460,7 +460,6 @@ def call_packed(
return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)
-@args_converter.auto
def call_py_func(
py_func_name: py_str,
*args: Expr,
@@ -485,6 +484,7 @@ def call_py_func(
call: Call
The created Relax Call for call_py_func operator.
"""
+ args = py_tuple(convert_to_expr(a) for a in args)
if isinstance(out_sinfo, py_tuple): # type: ignore
out_sinfo = list(out_sinfo)
elif not isinstance(out_sinfo, list):
diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py
deleted file mode 100644
index d156245452..0000000000
--- a/tests/python/relax/test_expr_args_converter.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: E731
-
-from collections.abc import Callable
-from typing import Any
-
-import pytest
-
-import tvm
-import tvm.testing
-from tvm import relax
-from tvm.relax import Expr
-from tvm.relax.utils import args_converter
-
-
-def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) ->
None:
- # Test converting to `Expr`
- assert f_checker(arg)
- # Test converting `*args`
- assert isinstance(args, tuple)
- assert all([f_checker(arg) for arg in args])
- # Test converting `**kwargs`
- assert isinstance(kwargs, dict)
- assert all([f_checker(arg) for arg in kwargs.values()])
-
-
-def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None:
- f_checker = lambda x: isinstance(x, Expr)
- _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_optional_expr(arg: Expr | None, *args: Expr | None, **kwargs: Expr |
None) -> None:
- f_checker = lambda x: x is None or isinstance(x, Expr)
- _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_list_expr(arg: list[Expr], *args: list[Expr], **kwargs: list[Expr])
-> None:
- f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr)
for arg in x])
- _test_base(f_checker, arg, *args, **kwargs)
-
-
-def _test_optional_list_expr(
- arg: list[Expr] | None, *args: list[Expr] | None, **kwargs: list[Expr] |
None
-) -> None:
- f_checker = lambda x: (
- x is None or (isinstance(x, list) and all([isinstance(arg, Expr) for
arg in x]))
- )
- _test_base(f_checker, arg, *args, **kwargs)
-
-
-prim_value = 1
-str_value = "value_to_be_convert"
-shape_value = (1, 1)
-tuple_value = (relax.const(1), (1, 1))
-placeholder = relax.const(0)
-
-test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder]
-
-
-def test_args_to_expr():
- for _f in [_test_expr, _test_optional_expr]:
- f = args_converter.to_expr("arg", "args", "kwargs")(_f)
- for x in test_cases:
- f(
- x,
- x, # the first argument in *args
- x, # the second argument in *args
- test_kwargs=x,
- )
-
- if _f == _test_optional_expr:
- f(None, None, x, test_kwargs=None)
-
-
-def test_args_to_list_expr():
- for _f in [_test_list_expr, _test_optional_list_expr]:
- f = args_converter.to_list_expr("arg", "args", "kwargs")(_f)
- for x in test_cases:
- f(
- [x],
- [x], # the first argument in *args
- [x, x], # the second argument in *args
- test_kwargs=[x, (x,)],
- )
-
- if _f == _test_optional_list_expr:
- f(None, None, [x], test_kwargs=None)
-
-
-def test_error():
- f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr)
- with pytest.raises(TypeError):
- f(prim_value) # fail to convert prim_value to `List[Expr]`
-
-
-def test_auto_convert():
- for _f in [_test_expr, _test_optional_expr]:
- f = args_converter.auto(_f)
- for x in test_cases:
- f(x, (x,), test_kwargs=x)
-
- if _f == _test_optional_expr:
- f(None, x, test_kwargs=None)
-
- for _f in [_test_list_expr, _test_optional_list_expr]:
- f = args_converter.auto(_f)
- for x in test_cases:
- f([x], [x, x], test_kwargs=[x, (x,)])
-
- if _f == _test_optional_list_expr:
- f(None, None, [x], test_kwargs=None)
-
-
-def test_auto_convert_skip():
- def _test_expr_skip(arg: int, *args: str | Expr, **kwargs: list[Expr |
None]) -> None:
- f_checker = lambda x: not isinstance(x, Expr)
- _test_base(f_checker, arg, *args, **kwargs)
-
- f = args_converter.auto(_test_expr_skip)
- f(1, "str", test_kwargs=[None])
-
-
-def test_empty_tuple():
- def _test(arg: Expr):
- assert isinstance(arg, relax.Tuple)
-
- f = args_converter.auto(_test)
- f(())
-
-
-if __name__ == "__main__":
- tvm.testing.main()