This is an automated email from the ASF dual-hosted git repository.
spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 96cba60464 [PYTHON] Autoload backends; simplify library loading;
remove TVMError for native errors (#19727)
96cba60464 is described below
commit 96cba60464ad8935dff9f0462e63140cb3182e4f
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 11 13:50:38 2026 -0400
[PYTHON] Autoload backends; simplify library loading; remove TVMError for
native errors (#19727)
This PR adds an autoload mechanism for out-of-tree backends, simplifies
TVM's Python library loading, and removes `TVMError` in favor of native
Python errors.
## Autoload out-of-tree backends
Out-of-tree packages can register an autoload callable under the
`tvm.backends` entry-point group (mirroring torch's device-backend
autoload). At `import tvm` startup each entry point is discovered and
its callable invoked once, after the core runtime and the `tvm`
namespace are fully initialized, so an extension can register
ops/targets/funcs or load extra libraries.
```toml
[project.entry-points."tvm.backends"]
tvm_foo = "tvm_foo:_autoload"
```
An exception raised while loading or invoking an extension's callable is
caught and reported via `warnings.warn`, so a single broken extension
cannot break `import tvm`. Autoload can be disabled by setting
`TVM_DEVICE_BACKEND_AUTOLOAD=0`.
## Simplify library loading
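The library-loading path in `base.py` is consolidated around a single
`_LOADED_LIBS` dict. It maps each successfully loaded library's name
(e.g. `tvm_runtime`, `tvm_compiler`, `tvm_runtime_cuda`) to its ctypes
handle, so downstream and autoloaded extensions can skip already-loaded
libraries. It replaces the `_LIB` / `_LIB_RUNTIME` module attributes;
for example, `tvm.base._LIB_RUNTIME` becomes
`tvm.base._LOADED_LIBS["tvm_runtime"]`.

The per-backend runtime DSO list is folded into `load_backend_libs`. A
compiler library that is present but fails to load (`OSError`) now also
falls back to runtime-only mode.

Accumulated cruft is removed:
- the minimum-Python-version (3.9) check;
- the readline shim;
- the `_FFI_MODE` ctypes check;
- the `base.__version__` re-export (import `__version__` from `tvm.libinfo` instead);
- `py_str` (call sites now call `.decode("utf-8")` inline).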
## Remove TVMError in favor of native Python errors
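`TVMError` added a layer atop `RuntimeError` that downstream code had to
import and learn. It is removed. The registered FFI error kinds
(`InternalError`, `RPCError`, `OpError`, `DiagnosticError`,
`ScheduleError`) now subclass `RuntimeError` directly while staying
registered, so the FFI keeps throwing the right kinds.

`TVMError` imports are dropped. `raise`/`isinstance` uses and
`pytest.raises(tvm.TVMError)` sites move to the `RuntimeError` builtin.
Most `except TVMError` handlers also become `except RuntimeError`, but a
few now catch other builtins instead, for example:
- `(KeyError, AttributeError)` around `Op.get`;
- `(KeyError, ValueError)` around an `IRModule` lookup;
- `Exception` in the ONNX importer.

Note that the top-level `tvm.TVMError` was an alias of `Exception`, so
handlers that caught it previously caught every exception; the new
handlers are narrower. Downstream code that still references
`tvm.TVMError` must switch to `RuntimeError`.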
---
docs/contribute/error_handling.rst | 2 +-
docs/install/from_source.rst | 4 +-
python/tvm/__init__.py | 12 ++-
python/tvm/_autoload_backends.py | 50 ++++++++++++
python/tvm/base.py | 95 +++++++---------------
python/tvm/error.py | 12 +--
python/tvm/ir/base.py | 2 +-
python/tvm/ir/module.py | 2 +-
python/tvm/ir/utils.py | 2 +-
.../relax/backend/contrib/example_npu/README.md | 2 +-
.../relax/backend/contrib/example_npu/patterns.py | 11 ++-
python/tvm/relax/binding_rewrite.py | 2 +-
python/tvm/relax/distributed/struct_info.py | 3 +-
python/tvm/relax/expr.py | 4 +-
python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 +-
python/tvm/relax/op/_op_gradient.py | 5 +-
python/tvm/relax/training/setup_trainer.py | 3 +-
python/tvm/relax/training/trainer.py | 6 +-
python/tvm/rpc/base.py | 4 +-
python/tvm/rpc/client.py | 5 +-
python/tvm/rpc/proxy.py | 3 +-
python/tvm/rpc/server.py | 5 +-
python/tvm/rpc/tracker.py | 3 +-
.../s_tir/meta_schedule/cost_model/mlp_model.py | 3 +-
python/tvm/s_tir/meta_schedule/utils.py | 3 +-
python/tvm/s_tir/schedule/schedule.py | 4 +-
python/tvm/script/parser/core/parser.py | 3 +-
python/tvm/support/cc.py | 15 ++--
python/tvm/support/clang.py | 3 +-
python/tvm/support/emcc.py | 3 +-
python/tvm/support/ndk.py | 7 +-
python/tvm/support/nvcc.py | 9 +-
python/tvm/support/rocm.py | 5 +-
python/tvm/support/tar.py | 5 +-
python/tvm/support/xcode.py | 5 +-
python/tvm/testing/utils.py | 3 +-
tests/python/arith/test_arith_analyzer_object.py | 2 +-
tests/python/arith/test_arith_rewrite_simplify.py | 8 +-
tests/python/arith/test_arith_simplify.py | 4 +-
tests/python/codegen/test_target_codegen.py | 8 +-
.../codegen/test_target_codegen_cross_llvm.py | 2 +-
tests/python/codegen/test_target_codegen_cuda.py | 2 +-
tests/python/codegen/test_target_codegen_llvm.py | 8 +-
.../test_hexagon/test_2d_physical_buffers.py | 2 +-
tests/python/contrib/test_hexagon/test_vtcm.py | 2 +-
tests/python/contrib/test_rpc_tracker.py | 2 +-
tests/python/ir/test_roundtrip_runtime_module.py | 1 -
.../distributed/test_distributed_dtensor_sinfo.py | 3 +-
.../relax/test_analysis_struct_info_analysis.py | 8 +-
.../test_analysis_suggest_layout_transforms.py | 4 +-
tests/python/relax/test_analysis_well_formed.py | 2 +-
tests/python/relax/test_bind_params.py | 4 +-
tests/python/relax/test_bind_symbolic_vars.py | 10 +--
tests/python/relax/test_binding_rewrite.py | 3 +-
tests/python/relax/test_blockbuilder_core.py | 2 +-
tests/python/relax/test_expr.py | 2 +-
.../python/relax/test_frontend_nn_extern_module.py | 7 +-
tests/python/relax/test_relax_operators.py | 3 +-
.../python/relax/test_runtime_builtin_rnn_state.py | 2 +-
tests/python/relax/test_struct_info.py | 6 +-
tests/python/relax/test_training_append_loss.py | 15 ++--
tests/python/relax/test_training_setup_trainer.py | 4 +-
.../python/relax/test_training_trainer_numeric.py | 6 +-
.../relax/test_transform_bind_symbolic_vars.py | 2 +-
.../relax/test_transform_fuse_ops_by_pattern.py | 2 +-
tests/python/relax/test_transform_fuse_tir.py | 2 +-
tests/python/relax/test_transform_gradient.py | 27 +++---
.../relax/test_transform_lift_transform_params.py | 6 +-
.../test_transform_static_plan_block_memory.py | 6 +-
tests/python/relax/test_vm_build.py | 2 +-
tests/python/relax/test_vm_builtin_lower.py | 2 +-
tests/python/relax/test_vm_cuda_graph.py | 2 +-
tests/python/relax/test_vm_execbuilder.py | 4 +-
tests/python/runtime/test_runtime_rpc.py | 2 +-
.../test_meta_schedule_post_order_apply.py | 3 +-
.../test_meta_schedule_postproc_rewrite_layout.py | 4 +-
.../test_meta_schedule_space_generator.py | 3 +-
.../s_tir/schedule/test_tir_schedule_rfactor.py | 2 +-
.../schedule/test_tir_schedule_transform_layout.py | 4 +-
...est_s_tir_transform_convert_blocks_to_opaque.py | 2 +-
...est_s_tir_transform_inject_software_pipeline.py | 2 +-
.../test_s_tir_transform_lower_match_buffer.py | 2 +-
.../transform/test_s_tir_transform_remove_undef.py | 5 +-
tests/python/target/test_target_target.py | 2 +-
tests/python/te/test_te_verify_compute.py | 13 ++-
tests/python/tirx-base/test_tir_base.py | 5 +-
tests/python/tirx-base/test_tir_constructor.py | 3 +-
tests/python/tirx-base/test_tir_imm_values.py | 33 ++++----
tests/python/tirx-base/test_tir_index_map.py | 2 +-
tests/python/tirx-base/test_tir_nodes.py | 22 ++---
tests/python/tirx-base/test_tir_ops.py | 2 +-
.../test_tir_unsafe_hide_buffer_access.py | 4 +-
...test_tir_transform_force_narrow_index_to_i32.py | 5 +-
.../test_tir_transform_lower_tvm_builtin.py | 4 +-
.../tirx-transform/test_tir_transform_vectorize.py | 2 +-
.../python/tvmscript/test_tvmscript_parser_tir.py | 2 +-
.../python/tvmscript/test_tvmscript_printer_ir.py | 4 +-
97 files changed, 301 insertions(+), 321 deletions(-)
diff --git a/docs/contribute/error_handling.rst
b/docs/contribute/error_handling.rst
index 754602c421..8d66a6cb74 100644
--- a/docs/contribute/error_handling.rst
+++ b/docs/contribute/error_handling.rst
@@ -40,7 +40,7 @@ Raise a Specific Error in C++
You can add ``<ErrorType>:`` prefix to your error message to
raise an error of the corresponding type.
Note that you do not have to add a new type
-:py:class:`tvm.error.TVMError` will be raised by default when
+:py:class:`tvm.error.InternalError` will be raised by default when
there is no error type prefix in the message.
This mechanism works for both ``LOG(FATAL)`` and ``TVM_FFI_ICHECK`` macros.
The following code gives an example on how to do so.
diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst
index a970bf5c1e..65e4f87e9e 100644
--- a/docs/install/from_source.rst
+++ b/docs/install/from_source.rst
@@ -193,8 +193,8 @@ Therefore, it is highly recommended to validate Apache TVM
installation before u
.. code-block:: bash
- >>> python -c "import tvm; print(tvm.base._LIB)"
- <CDLL '/some-path/lib/python3.11/site-packages/tvm/libtvm.dylib', handle
95ada510 at 0x1030e4e50>
+ >>> python -c "import tvm; print(tvm.base._LOADED_LIBS['tvm_runtime'])"
+ <CDLL '/some-path/lib/python3.11/site-packages/tvm/libtvm_runtime.dylib',
handle 95ada510 at 0x1030e4e50>
**Step 3. Reflect TVM build option.** Sometimes when downstream application
fails, it could likely be some mistakes with a wrong TVM commit, or wrong build
flags. To find it out, the following commands will be helpful:
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index e49e9fa0a4..bb171170fa 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -26,9 +26,9 @@ import os
from tvm_ffi import register_object, register_global_func, get_global_func
# top-level alias
-from .base import TVMError, __version__, _RUNTIME_ONLY
+from .libinfo import __version__
+from .base import _RUNTIME_ONLY
-# top-level alias
# tvm.runtime
from .runtime import Object
from .runtime._tensor import device, cpu, cuda, opencl, vulkan, metal
@@ -117,3 +117,11 @@ def tvm_wrap_excepthook(exception_hook):
sys.excepthook = tvm_wrap_excepthook(sys.excepthook)
+
+# Autoload out-of-tree backends registered under the ``tvm.backends`` entry
+# point group. Runs last, after the core runtime and the tvm namespace are
+# fully initialized, so an extension can safely register into ``tvm.*`` and
+# load extra libraries. Imported lazily here to avoid any import-cycle risk.
+from ._autoload_backends import _autoload_backends
+
+_autoload_backends()
diff --git a/python/tvm/_autoload_backends.py b/python/tvm/_autoload_backends.py
new file mode 100644
index 0000000000..b45ac59d9d
--- /dev/null
+++ b/python/tvm/_autoload_backends.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Autoload out-of-tree backends registered via ``tvm.backends`` entry points.
+
+Out-of-tree extensions opt into being loaded automatically at ``import tvm``
+time by declaring an entry point in the ``tvm.backends`` group::
+
+ [project.entry-points."tvm.backends"]
+ tvm_foo = "tvm_foo:_autoload"
+
+Autoload can be disabled via ``TVM_DEVICE_BACKEND_AUTOLOAD=0``.
+"""
+
+import os
+import warnings
+from importlib.metadata import entry_points
+
+# Guard so autoload runs at most once per process, even if invoked again.
+_AUTO_LOAD_DONE = False
+
+
+def _autoload_backends():
+ """Discover and invoke out-of-tree backends registered via entry points."""
+ global _AUTO_LOAD_DONE
+ if _AUTO_LOAD_DONE:
+ return
+ _AUTO_LOAD_DONE = True
+
+ if os.environ.get("TVM_DEVICE_BACKEND_AUTOLOAD", "1") == "0":
+ return
+
+ for entry_pt in entry_points(group="tvm.backends"):
+ try:
+ entry_pt.load()()
+ except Exception as e: # pylint: disable=broad-except
+ warnings.warn(f"Failed to autoload tvm backend '{entry_pt.name}':
{e}")
diff --git a/python/tvm/base.py b/python/tvm/base.py
index f3f42263b4..5c1e75566e 100644
--- a/python/tvm/base.py
+++ b/python/tvm/base.py
@@ -16,41 +16,42 @@
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, import-outside-toplevel
-# ruff: noqa: F401
"""Base library for TVM."""
import os
-import sys
from pathlib import Path
from tvm_ffi.libinfo import load_lib_ctypes
from . import libinfo
-# ----------------------------
-# Python3 version.
-# ----------------------------
-if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 9):
- PY3STATEMENT = "The minimal Python requirement is Python 3.9"
- raise Exception(PY3STATEMENT)
-
# ----------------------------
# library loading
# ----------------------------
-# Known per-backend runtime DSOs that, when present, are loaded with
-# RTLD_GLOBAL so their static initializers register the device backend.
-_BACKEND_RUNTIME_LIBS = ["cuda", "vulkan", "opencl", "metal", "rocm",
"hexagon", "extra"]
+# Whether only the runtime library is loaded (runtime-only wheel, or
+# ``TVM_USE_RUNTIME_LIB=1``). Set during library loading below.
+_RUNTIME_ONLY = os.environ.get("TVM_USE_RUNTIME_LIB") == "1"
+
+# Handles of the core libraries actually loaded, keyed by basename
+# (e.g. ``{"tvm_runtime": <CDLL>, "tvm_compiler": <CDLL>}``). Downstream /
+# autoloaded extensions can inspect this to skip duplicate libraries
+# (``"tvm_runtime" in _LOADED_LIBS``) and obtain the loaded handle.
+_LOADED_LIBS = {}
def load_backend_libs(runtime_lib_path: str) -> None:
- """Try to load each known backend runtime DSO; failures are silent."""
+ """Load each known backend runtime DSO into ``_LOADED_LIBS``; failures are
silent."""
+ # Known per-backend runtime DSOs that, when present, are loaded with
+ # RTLD_GLOBAL so their static initializers register the device backend.
+ backend_runtime_libs = ["cuda", "vulkan", "opencl", "metal", "rocm",
"hexagon", "extra"]
runtime_dir = Path(runtime_lib_path).resolve().parent
- for backend in _BACKEND_RUNTIME_LIBS:
+ for backend in backend_runtime_libs:
+ target_name = f"tvm_runtime_{backend}"
try:
- load_lib_ctypes(
+ _LOADED_LIBS[target_name] = load_lib_ctypes(
package="tvm",
- target_name=f"tvm_runtime_{backend}",
+ target_name=target_name,
mode="RTLD_GLOBAL",
extra_lib_paths=[runtime_dir],
)
@@ -58,68 +59,28 @@ def load_backend_libs(runtime_lib_path: str) -> None:
pass
-# The TVM C++ side is split into two shared libraries:
-#
-# - ``libtvm_runtime`` — runtime-only sources. Loaded with ``RTLD_GLOBAL`` so
-# its symbols are exposed to subsequent loads (NVRTC kernels, downstream
-# modules and so on resolve runtime symbols at link time).
-# - ``libtvm_compiler`` — compiler / IR / transform sources, links against
-# ``libtvm_runtime``. Loaded with ``RTLD_LOCAL`` so compiler internals
-# don't leak into the global symbol namespace.
-#
-# If the environment variable ``TVM_USE_RUNTIME_LIB`` is set to ``"1"``, or
-# the compiler library is simply not present (runtime-only wheel), only the
-# runtime is loaded and ``_LIB`` aliases ``_LIB_RUNTIME``.
-_extra_lib_paths = libinfo.package_lib_paths()
-_LIB_RUNTIME = load_lib_ctypes(
- "tvm", "tvm_runtime", "RTLD_GLOBAL", extra_lib_paths=_extra_lib_paths
+# runtime is loaded RTLD_GLOBAL to expose its symbols to subsequent loads;
+# compiler is loaded RTLD_LOCAL.
+_LOADED_LIBS["tvm_runtime"] = load_lib_ctypes(
+ "tvm", "tvm_runtime", "RTLD_GLOBAL",
extra_lib_paths=libinfo.package_lib_paths()
)
# After libtvm_runtime.so is in the global symbol namespace, scan the same
# directory for per-backend DSOs (libtvm_runtime_cuda.so, etc.) and load each
# with RTLD_GLOBAL so their static initializers register device backends.
-# Failures are swallowed silently — a missing driver just means that backend
-# is unavailable, not an error.
-load_backend_libs(_LIB_RUNTIME._name)
+load_backend_libs(_LOADED_LIBS["tvm_runtime"]._name)
-_RUNTIME_ONLY = os.environ.get("TVM_USE_RUNTIME_LIB") == "1"
-if _RUNTIME_ONLY:
- _LIB = _LIB_RUNTIME
-else:
+if not _RUNTIME_ONLY:
try:
- _LIB = load_lib_ctypes(
- "tvm", "tvm_compiler", "RTLD_LOCAL",
extra_lib_paths=_extra_lib_paths
+ _LOADED_LIBS["tvm_compiler"] = load_lib_ctypes(
+ "tvm", "tvm_compiler", "RTLD_LOCAL",
extra_lib_paths=libinfo.package_lib_paths()
)
- except RuntimeError:
- # Compiler lib not present — fall back to runtime-only mode.
- _LIB = _LIB_RUNTIME
+ except (RuntimeError, OSError):
+ # Compiler lib not present, or present but unloadable (missing LLVM
+ # deps / linker issues) — fall back to runtime-only mode.
_RUNTIME_ONLY = True
-
-try:
- # The following import is needed for TVM to work with pdb
- import readline # pylint: disable=unused-import
-except ImportError:
- pass
-
-# version number
-__version__ = libinfo.__version__
-
-
if _RUNTIME_ONLY:
from tvm_ffi import registry as _tvm_ffi_registry
_tvm_ffi_registry._SKIP_UNKNOWN_OBJECTS = True
-
-# The FFI mode of TVM
-_FFI_MODE = os.environ.get("TVM_FFI", "auto")
-
-if _FFI_MODE == "ctypes":
- raise ImportError("We have phased out ctypes support in favor of cython on
wards")
-
-
-def py_str(x):
- return x.decode("utf-8")
-
-
-TVMError = Exception
diff --git a/python/tvm/error.py b/python/tvm/error.py
index 180805432a..e1a5fdd9b2 100644
--- a/python/tvm/error.py
+++ b/python/tvm/error.py
@@ -29,12 +29,8 @@ copy the examples and raise errors with the same message
convention.
from tvm_ffi import register_error
-class TVMError(RuntimeError):
- pass
-
-
@register_error
-class InternalError(TVMError):
+class InternalError(RuntimeError):
"""Internal error in the system.
Examples
@@ -52,7 +48,7 @@ class InternalError(TVMError):
@register_error
-class RPCError(TVMError):
+class RPCError(RuntimeError):
"""Error thrown by the remote server handling the RPC call."""
@@ -62,7 +58,7 @@ class RPCSessionTimeoutError(RPCError, TimeoutError):
@register_error
-class OpError(TVMError):
+class OpError(RuntimeError):
"""Base class of all operator errors in frontends."""
@@ -123,7 +119,7 @@ class OpAttributeUnImplemented(OpError,
NotImplementedError):
@register_error
-class DiagnosticError(TVMError):
+class DiagnosticError(RuntimeError):
"""Error diagnostics were reported during the execution of a pass.
See the configured diagnostic renderer for detailed error information.
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 6bae30f791..e593783ac0 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -22,7 +22,7 @@ from tvm_ffi.serialization import from_json_graph_str,
to_json_graph_str
from tvm.runtime import Object
-from ..base import __version__
+from ..libinfo import __version__
from . import _ffi_api, json_compact
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 95b9d940ec..a883158d35 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -181,7 +181,7 @@ class IRModule(Node, Scriptable):
Raises
------
- tvm.error.TVMError if we cannot find corresponding global var.
+ RuntimeError if we cannot find corresponding global var.
"""
return _ffi_api.Module_GetGlobalVar(self, name)
diff --git a/python/tvm/ir/utils.py b/python/tvm/ir/utils.py
index 202df8b94d..e5d012220f 100644
--- a/python/tvm/ir/utils.py
+++ b/python/tvm/ir/utils.py
@@ -75,7 +75,7 @@ def derived_object(cls: type[T]) -> type[T]:
return method
# for task scheduler return None means calling default function
- # otherwise it will trigger a TVMError of method not implemented
+ # otherwise it will trigger a RuntimeError of method not implemented
# on the c++ side when you call the method
return None
diff --git a/python/tvm/relax/backend/contrib/example_npu/README.md
b/python/tvm/relax/backend/contrib/example_npu/README.md
index 0b5119f80b..7e88a0ece0 100644
--- a/python/tvm/relax/backend/contrib/example_npu/README.md
+++ b/python/tvm/relax/backend/contrib/example_npu/README.md
@@ -235,7 +235,7 @@ This shows the registered patterns and that matched
subgraphs were turned into c
- **Layout preferences**: NHWC channel-last layouts preferred by NPUs
### Error Handling
-- **Robust exception handling**: Uses specific `TVMError` instead of generic
exceptions
+- **Robust exception handling**: Catches specific exception types instead of
generic exceptions
- **Comprehensive testing**: Validates both successful cases and error
conditions
## Learn More
diff --git a/python/tvm/relax/backend/contrib/example_npu/patterns.py
b/python/tvm/relax/backend/contrib/example_npu/patterns.py
index d224388fa3..17d6656ef8 100644
--- a/python/tvm/relax/backend/contrib/example_npu/patterns.py
+++ b/python/tvm/relax/backend/contrib/example_npu/patterns.py
@@ -24,7 +24,6 @@ tiling, and fusion strategies.
from typing import ClassVar
-from tvm import TVMError
from tvm.ir import Op
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.relax.transform import PatternCheckContext
@@ -372,7 +371,7 @@ def softmax_patterns():
try:
Op.get("relax.nn.softmax")
patterns.append(("example_npu.softmax", *_make_softmax_pattern(),
_check_softmax))
- except TVMError: # pylint: disable=broad-exception-caught
+ except (KeyError, AttributeError):
pass
return patterns
@@ -415,7 +414,7 @@ def activation_patterns():
for pattern_name, op_name in activations:
try:
Op.get(op_name)
- except TVMError: # pylint: disable=broad-exception-caught
+ except (KeyError, AttributeError):
continue
pattern_fn = _make_activation_pattern(op_name)
@@ -456,7 +455,7 @@ def elementwise_patterns():
for op in ops:
try:
Op.get(op)
- except TVMError: # pylint: disable=broad-exception-caught
+ except (KeyError, AttributeError):
continue
op_short = op.split(".")[-1]
@@ -505,7 +504,7 @@ def quantization_patterns():
try:
Op.get("relax.quantize")
patterns.append(("example_npu.quantize", *_make_quantize_pattern(),
_check_quantization))
- except TVMError: # pylint: disable=broad-exception-caught
+ except (KeyError, AttributeError):
pass
try:
@@ -513,7 +512,7 @@ def quantization_patterns():
patterns.append(
("example_npu.dequantize", *_make_dequantize_pattern(),
_check_quantization)
)
- except TVMError: # pylint: disable=broad-exception-caught
+ except (KeyError, AttributeError):
pass
return patterns
diff --git a/python/tvm/relax/binding_rewrite.py
b/python/tvm/relax/binding_rewrite.py
index 0198befd33..bd46029243 100644
--- a/python/tvm/relax/binding_rewrite.py
+++ b/python/tvm/relax/binding_rewrite.py
@@ -109,7 +109,7 @@ class DataflowBlockRewrite(Object):
Raises
------
- TVMError if the variable is used or undefined (allow_undef=False).
+ RuntimeError if the variable is used or undefined (allow_undef=False).
"""
_ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type:
ignore
diff --git a/python/tvm/relax/distributed/struct_info.py
b/python/tvm/relax/distributed/struct_info.py
index 39e61f4adc..0c94d3dda9 100644
--- a/python/tvm/relax/distributed/struct_info.py
+++ b/python/tvm/relax/distributed/struct_info.py
@@ -21,7 +21,6 @@ import enum
import tvm_ffi
-from tvm import TVMError
from tvm.ir import Span
from tvm.relax.struct_info import StructInfo, TensorStructInfo
from tvm.runtime import Object
@@ -52,7 +51,7 @@ class PlacementSpec(Object):
kind: PlacementSpecKind
def __init__(self, *args, **kwargs):
- raise TVMError("PlacementSpec is not intended to be constructed
directly, ")
+ raise RuntimeError("PlacementSpec is not intended to be constructed
directly, ")
@staticmethod
def sharding(axis: int) -> "PlacementSpec":
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 730febc51c..734d7bde67 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -236,11 +236,11 @@ class ExprWithOp(Expr, Scriptable):
"""
try:
return TupleGetItem(self, index)
- except tvm.TVMError as err:
+ except RuntimeError as err:
# For Python objects with __getitem__, but without
# __len__, tuple unpacking is done by iterating over
# sequential indices until IndexError is raised.
- # Therefore, convert from TVMError to IndexError for
+ # Therefore, convert from RuntimeError to IndexError for
# compatibility.
if "Index out of bounds" in err.args[0]:
raise IndexError from err
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 31971b30ab..3d9dfba9a1 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -57,7 +57,7 @@ except ImportError as err:
import tvm_ffi
import tvm
-from tvm import TVMError, relax, tirx, topi
+from tvm import relax, tirx, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
from tvm.runtime import DataType, DataTypeCode
@@ -71,7 +71,7 @@ def _relax_dtype_is_floating_point(dtype: str) -> bool:
"""Whether a Relax dtype string is a floating point type."""
try:
code = DataType(dtype).type_code
- except (ValueError, TypeError, TVMError):
+ except (ValueError, TypeError, RuntimeError):
return False
return (
code == DataTypeCode.FLOAT
@@ -537,7 +537,7 @@ class Div(BinaryBase):
try:
lhs_code = DataType(inputs[0].struct_info.dtype).type_code
rhs_code = DataType(inputs[1].struct_info.dtype).type_code
- except (AttributeError, ValueError, TypeError, TVMError):
+ except (AttributeError, ValueError, TypeError, RuntimeError):
return cls.base_impl(bb, inputs, attr, params)
lhs_is_integer = lhs_code == DataTypeCode.INT or lhs_code ==
DataTypeCode.UINT
@@ -5576,7 +5576,7 @@ class ONNXGraphImporter:
# Create struct information for the new operator.
if isinstance(op, relax.Expr):
op = self.bb.normalize(op)
- except TVMError as err:
+ except Exception as err: # pylint: disable=broad-exception-caught
print(f"Error converting operator {op_name}, with inputs:
{inputs}")
raise err
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index 377874f705..2d949ee276 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -22,7 +22,6 @@ import operator
from tvm import relax
from tvm.arith import Analyzer
-from tvm.base import TVMError
from tvm.relax.struct_info import ShapeStructInfo
from ...tirx import PrimExpr
@@ -68,7 +67,7 @@ def _get_shape(expr: Expr) -> ShapeExpr:
try:
shape = expr.struct_info.shape
except Exception as error:
- raise TVMError(
+ raise RuntimeError(
f"Get the shape of {expr} failed. Please normalize it first and
ensure it is a Tensor."
) from error
return shape
@@ -79,7 +78,7 @@ def _get_dtype(expr: Expr) -> str:
try:
dtype = expr.struct_info.dtype
except Exception as error:
- raise TVMError(
+ raise RuntimeError(
f"Get the dtype of {expr} failed. Please normalize it first and
ensure it is a Tensor."
) from error
return dtype
diff --git a/python/tvm/relax/training/setup_trainer.py
b/python/tvm/relax/training/setup_trainer.py
index 99fa6fedec..03ae52a61f 100644
--- a/python/tvm/relax/training/setup_trainer.py
+++ b/python/tvm/relax/training/setup_trainer.py
@@ -18,7 +18,6 @@
"""Setup Trainer Pass."""
import tvm
-from tvm import TVMError
from tvm.ir.module import IRModule
from tvm.tirx.expr import IntImm
@@ -132,7 +131,7 @@ class SetupTrainer:
raise ValueError("SetupTrainer: The backbone module is not well
formed.")
try:
func = mod[self.BACKBONE_FUNC]
- except TVMError as exc:
+ except (KeyError, ValueError) as exc:
raise ValueError(
f"SetupTrainer: The backbone module does not contain a
function named "
f"{self.BACKBONE_FUNC}"
diff --git a/python/tvm/relax/training/trainer.py
b/python/tvm/relax/training/trainer.py
index 36c6992e89..d7bdea3e1f 100644
--- a/python/tvm/relax/training/trainer.py
+++ b/python/tvm/relax/training/trainer.py
@@ -20,7 +20,7 @@
import numpy as np # type: ignore
import tvm
-from tvm import TVMError, relax
+from tvm import relax
from tvm.ir.module import IRModule
from tvm.runtime._tensor import Tensor
@@ -243,14 +243,14 @@ class Trainer:
"""Check that all parameters and model states are initialized."""
idx_not_inited_param = next((i for i, p in enumerate(self._params) if
p is None), -1)
if idx_not_inited_param != -1:
- raise TVMError(
+ raise RuntimeError(
f"The {idx_not_inited_param}-th parameter is not initialized
before training or "
"inference."
)
idx_not_inited_state = next((i for i, s in enumerate(self._states) if
s is None), -1)
if idx_not_inited_state != -1:
- raise TVMError(
+ raise RuntimeError(
f"The {idx_not_inited_state}-th model state is not initialized
before training or "
"inference."
)
diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py
index 28379d742d..dd1a33fc9e 100644
--- a/python/tvm/rpc/base.py
+++ b/python/tvm/rpc/base.py
@@ -25,8 +25,6 @@ import socket
import struct
import time
-from ..base import py_str
-
# Magic header for RPC data plane
RPC_MAGIC = 0xFF271
# magic header for RPC tracker(control plane)
@@ -117,7 +115,7 @@ def recvjson(sock):
The value received.
"""
size = struct.unpack("<i", recvall(sock, 4))[0]
- data = json.loads(py_str(recvall(sock, size)))
+ data = json.loads((recvall(sock, size)).decode("utf-8"))
return data
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index 57ef8f1842..7dde53cc2f 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -28,7 +28,6 @@ import tvm_ffi
from tvm_ffi import DLDeviceType
import tvm.runtime
-from tvm.base import TVMError
from tvm.support import utils
from . import _ffi_api, base, server
@@ -433,7 +432,7 @@ class TrackerSession:
except OSError as err:
self.close()
last_err = err
- except TVMError as err:
+ except RuntimeError as err:
last_err = err
raise RuntimeError(f"Cannot request {key} after {max_retry} retry,
last_error:{last_err!s}")
@@ -468,7 +467,7 @@ class TrackerSession:
sess = self.request(key, priority=priority,
session_timeout=session_timeout)
tstart = time.time()
return func(sess)
- except TVMError as err:
+ except RuntimeError as err:
duration = time.time() - tstart
# roughly estimate if the error is due to timeout termination
if session_timeout and duration >= session_timeout * 0.95:
diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py
index bfaea7c69c..575580997e 100644
--- a/python/tvm/rpc/proxy.py
+++ b/python/tvm/rpc/proxy.py
@@ -45,7 +45,6 @@ except ImportError as error_msg:
from tvm.support.popen_pool import PopenWorker
-from ..base import py_str
from . import _ffi_api, base
from .base import TrackerCode
from .server import _server_env
@@ -90,7 +89,7 @@ class ForwardHandler:
self._init_req_nbytes = self._rpc_key_length
elif self.rpc_key is None:
assert len(message) == self._rpc_key_length
- self.rpc_key = py_str(message)
+ self.rpc_key = message.decode("utf-8")
# match key is used to do the matching
self.match_key = self.rpc_key[7:].split()[0]
self.on_start()
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 4a28da6726..963486eb90 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -42,7 +42,6 @@ from pathlib import Path
import tvm_ffi
-from tvm.base import py_str
from tvm.runtime.module import load_module as _load_module
from tvm.support import utils
from tvm.support.popen_pool import PopenWorker
@@ -243,7 +242,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr,
load_library, custom_addr):
conn.close()
continue
keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
- key = py_str(base.recvall(conn, keylen))
+ key = (base.recvall(conn, keylen)).decode("utf-8")
arr = key.split()
expect_header = "client:" + matchkey
server_key = "server:" + rpc_key
@@ -309,7 +308,7 @@ def _connect_proxy_loop(addr, key, load_library):
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError(f"{addr!s} is not RPC Proxy")
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
- remote_key = py_str(base.recvall(sock, keylen))
+ remote_key = (base.recvall(sock, keylen)).decode("utf-8")
_serving(sock, addr, _parse_server_opt(remote_key.split()[1:]),
load_library)
retry_count = 0
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 07e0a4302b..b983112a33 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -62,7 +62,6 @@ except ImportError as error_msg:
f"RPCTracker module requires tornado package {error_msg}. Try 'pip
install tornado'."
)
-from ..base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
@@ -242,7 +241,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
else:
return
if self._msg_size != 0 and len(self._data) >= self._msg_size:
- msg = py_str(bytes(self._data[: self._msg_size]))
+ msg = (bytes(self._data[: self._msg_size])).decode("utf-8")
del self._data[: self._msg_size]
self._msg_size = 0
try:
diff --git a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
index a9bb7c784d..52bc9ce99d 100644
--- a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
+++ b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
@@ -31,7 +31,6 @@ from typing import NamedTuple
import numpy as np # type: ignore
import torch # type: ignore
-import tvm
from tvm.ir.utils import derived_object
from tvm.support.tar import tar, untar
@@ -545,7 +544,7 @@ class State:
"_workload.json", "_candidates.json"
),
)
- except tvm.base.TVMError:
+ except (ValueError, RuntimeError):
continue
candidates, results = [], []
tuning_records = database.get_all_tuning_records()
diff --git a/python/tvm/s_tir/meta_schedule/utils.py
b/python/tvm/s_tir/meta_schedule/utils.py
index d80fbf52b1..fdadf4a301 100644
--- a/python/tvm/s_tir/meta_schedule/utils.py
+++ b/python/tvm/s_tir/meta_schedule/utils.py
@@ -24,7 +24,6 @@ from typing import Any
import numpy as np # type: ignore
from tvm_ffi import Array, Function, Map, get_global_func, register_global_func
-from tvm.error import TVMError
from tvm.ir import IRModule
from tvm.rpc import RPCSession
from tvm.tirx import FloatImm, IntImm
@@ -159,7 +158,7 @@ def get_global_func_with_default_on_worker(
return name
try:
return get_global_func(name)
- except TVMError as error:
+ except (ValueError, RuntimeError) as error:
raise ValueError(
"Function '{name}' is not registered on the worker process. "
"The build function and export function should be registered in
the worker process. "
diff --git a/python/tvm/s_tir/schedule/schedule.py
b/python/tvm/s_tir/schedule/schedule.py
index 4e9dc70941..7f191df98d 100644
--- a/python/tvm/s_tir/schedule/schedule.py
+++ b/python/tvm/s_tir/schedule/schedule.py
@@ -22,7 +22,7 @@ from typing import Literal
from tvm_ffi import register_object as _register_object
-from tvm.error import TVMError, register_error
+from tvm.error import register_error
from tvm.ir import GlobalVar, IRModule, PrimExpr
from tvm.runtime import Object
from tvm.tirx import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock
@@ -35,7 +35,7 @@ from .trace import Trace
@register_error
-class ScheduleError(TVMError):
+class ScheduleError(RuntimeError):
"""Error that happens during TensorIR scheduling."""
diff --git a/python/tvm/script/parser/core/parser.py
b/python/tvm/script/parser/core/parser.py
index 99c01b1641..f776e66cce 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -25,7 +25,6 @@ from typing import Any
import numpy as np
-from tvm.base import TVMError
from tvm.error import DiagnosticError
from tvm.ir import GlobalVar
@@ -616,7 +615,7 @@ class Parser(doc.NodeVisitor):
raise err
# Only take the last line of the error message
- if isinstance(err, TVMError):
+ if isinstance(err, RuntimeError):
lines = list(filter(None, str(err).split("\n")))
msg = lines[-1] if lines else (str(err) or type(err).__name__)
elif isinstance(err, KeyError):
diff --git a/python/tvm/support/cc.py b/python/tvm/support/cc.py
index 85efd46fb1..3bb4f9f41f 100644
--- a/python/tvm/support/cc.py
+++ b/python/tvm/support/cc.py
@@ -23,7 +23,6 @@ import subprocess
# pylint: disable=invalid-name
import sys
-from ..base import py_str
from . import tar as _tar
from . import utils as _utils
@@ -117,7 +116,7 @@ def _linux_ar(output, inputs, ar):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "AR error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
@@ -211,10 +210,10 @@ def get_global_symbol_section_map(path, *, nm=None) ->
dict[str, str]:
if proc.returncode != 0:
msg = "Runtime error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
- for line in py_str(out).split("\n"):
+ for line in out.decode("utf-8", errors="replace").split("\n"):
data = line.strip().split()
if len(data) != 3:
continue
@@ -246,9 +245,9 @@ def get_target_by_dump_machine(compiler):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
return None
- return py_str(out)
+ return out.decode("utf-8", errors="replace")
return None
return get_target_triple
@@ -367,7 +366,7 @@ def _linux_compile(
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
@@ -414,6 +413,6 @@ def _windows_compile(output, objects, options, cwd=None,
ccache_env=None):
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += " ".join(cmd) + "\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
diff --git a/python/tvm/support/clang.py b/python/tvm/support/clang.py
index fa7bef52ef..ff8e998ee8 100644
--- a/python/tvm/support/clang.py
+++ b/python/tvm/support/clang.py
@@ -20,7 +20,6 @@
import subprocess
import tvm.target
-from tvm.base import py_str
from . import utils
@@ -105,7 +104,7 @@ def create_llvm(inputs, output=None, options=None, cc=None):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
return open(output).read()
diff --git a/python/tvm/support/emcc.py b/python/tvm/support/emcc.py
index 9bd6d24036..f1455c580b 100644
--- a/python/tvm/support/emcc.py
+++ b/python/tvm/support/emcc.py
@@ -22,7 +22,6 @@ import subprocess
from pathlib import Path
from tvm import libinfo
-from tvm.base import py_str
def find_wasm_lib(name, optional=False):
@@ -139,7 +138,7 @@ def create_tvmjs_wasm(output, objects, options=None,
cc="emcc", libs=None):
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
diff --git a/python/tvm/support/ndk.py b/python/tvm/support/ndk.py
index 5c2de723cc..85b6ecb92f 100644
--- a/python/tvm/support/ndk.py
+++ b/python/tvm/support/ndk.py
@@ -25,7 +25,6 @@ from pathlib import Path
from tvm_ffi import register_global_func
-from ..base import py_str
from . import cc as _cc
from . import tar as _tar
from . import utils as _utils
@@ -67,7 +66,7 @@ def create_shared(output, objects, options=None):
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
@@ -110,7 +109,7 @@ def create_staticlib(output, inputs):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "AR error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
@@ -121,7 +120,7 @@ def create_staticlib(output, inputs):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Ranlib error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py
index 94dbd59ff6..859dbd3077 100644
--- a/python/tvm/support/nvcc.py
+++ b/python/tvm/support/nvcc.py
@@ -28,7 +28,6 @@ import tvm_ffi
import tvm
from tvm.target import Target
-from ..base import py_str
from . import utils
@@ -228,7 +227,7 @@ def _compile_cuda_nvcc(
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
# Second stage for NVSHMEM
@@ -249,7 +248,7 @@ def _compile_cuda_nvcc(
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
file_target = f"{file_prefix}.cubin"
@@ -660,7 +659,7 @@ def find_cuda_path():
cmd = ["which", "nvcc"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- out = py_str(out)
+ out = out.decode("utf-8", errors="replace")
if proc.returncode == 0:
return os.path.realpath(os.path.join(str(out).strip(), "../.."))
cuda_path = "/usr/local/cuda"
@@ -702,7 +701,7 @@ def get_cuda_version(cuda_path=None):
cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- out = py_str(out)
+ out = out.decode("utf-8", errors="replace")
if proc.returncode == 0:
release_line = next(line for line in out.split("\n") if "release" in
line)
release_fields = [s.strip() for s in release_line.split(",")]
diff --git a/python/tvm/support/rocm.py b/python/tvm/support/rocm.py
index 628c95ac40..24a572f239 100644
--- a/python/tvm/support/rocm.py
+++ b/python/tvm/support/rocm.py
@@ -25,7 +25,6 @@ import tvm_ffi
import tvm.runtime
import tvm.target
-from tvm.base import py_str
from . import utils
@@ -97,7 +96,7 @@ def rocm_link(in_file, out_file, lld=None):
if proc.returncode != 0:
msg = "Linking error using ld.lld:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
@@ -285,7 +284,7 @@ def find_rocm_path():
cmd = ["which", "hipcc"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
- out = out.decode("utf-8").strip()
+ out = out.decode("utf-8", errors="replace").strip()
if proc.returncode == 0:
return os.path.realpath(os.path.join(out, "../.."))
rocm_path = "/opt/rocm"
diff --git a/python/tvm/support/tar.py b/python/tvm/support/tar.py
index d0dc3f01eb..024c6a25ba 100644
--- a/python/tvm/support/tar.py
+++ b/python/tvm/support/tar.py
@@ -22,7 +22,6 @@ import os
import shutil
import subprocess
-from ..base import py_str
from . import utils
@@ -55,7 +54,7 @@ def tar(output, files):
if proc.returncode != 0:
msg = "Tar error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
@@ -83,7 +82,7 @@ def untar(tar_file, directory):
if proc.returncode != 0:
msg = "Tar error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
diff --git a/python/tvm/support/xcode.py b/python/tvm/support/xcode.py
index d67a86395b..ddfaa0e667 100644
--- a/python/tvm/support/xcode.py
+++ b/python/tvm/support/xcode.py
@@ -23,7 +23,6 @@ import os
import subprocess
import sys
-from ..base import py_str
from . import utils
@@ -100,7 +99,7 @@ def create_dylib(output, objects, arch, sdk="macosx",
min_os_version=None):
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
raise RuntimeError(msg)
@@ -161,7 +160,7 @@ def compile_metal(code, path_target=None, sdk="macosx",
min_os_version=None):
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
- sys.stderr.write(py_str(out))
+ sys.stderr.write(out.decode("utf-8", errors="replace"))
sys.stderr.flush()
libbin = None
else:
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 5cf96a7e07..3b0dea1bb5 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -93,7 +93,6 @@ import tvm.support.utils
import tvm.te
import tvm.tirx
from tvm.contrib import cudnn
-from tvm.error import TVMError
from tvm.support import nvcc, rocm
from tvm.target import codegen
@@ -441,7 +440,7 @@ def _get_targets(target_names=None):
)
return _get_targets(["llvm"])
- raise TVMError(
+ raise RuntimeError(
"None of the following targets are supported by this build of TVM:
%s."
" Try setting TVM_TEST_TARGETS to a supported target."
" Cannot default to llvm, as it is not enabled." % target_names
diff --git a/tests/python/arith/test_arith_analyzer_object.py
b/tests/python/arith/test_arith_analyzer_object.py
index 2b3931dfd9..4b4c4134b9 100644
--- a/tests/python/arith/test_arith_analyzer_object.py
+++ b/tests/python/arith/test_arith_analyzer_object.py
@@ -137,7 +137,7 @@ def test_analyzer_object_set_maximum_rewrite_steps():
capped = tvm.arith.Analyzer()
capped.set_maximum_rewrite_steps(1)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
capped.rewrite_simplify(expr)
# A generous limit must not interfere with normal simplification.
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index 31c944e179..e0ef9da822 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -1265,10 +1265,10 @@ class TestDivZero(BaseCompare):
broadcast = tvm.tirx.Broadcast(0, 2)
test_case = tvm.testing.parameter(
- TestCase(tvm.tirx.Div(ramp, broadcast), tvm.error.TVMError),
- TestCase(tvm.tirx.Mod(ramp, broadcast), tvm.error.TVMError),
- TestCase(tvm.tirx.FloorDiv(ramp, broadcast), tvm.error.TVMError),
- TestCase(tvm.tirx.FloorMod(ramp, broadcast), tvm.error.TVMError),
+ TestCase(tvm.tirx.Div(ramp, broadcast), RuntimeError),
+ TestCase(tvm.tirx.Mod(ramp, broadcast), RuntimeError),
+ TestCase(tvm.tirx.FloorDiv(ramp, broadcast), RuntimeError),
+ TestCase(tvm.tirx.FloorMod(ramp, broadcast), RuntimeError),
)
diff --git a/tests/python/arith/test_arith_simplify.py
b/tests/python/arith/test_arith_simplify.py
index 5202dcba2c..556ce32b5b 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -150,9 +150,7 @@ def test_bind_allow_override():
ana.bind(x, tvm.ir.Range(0, 5), allow_override=True)
assert ana.can_prove(x < 5)
- with pytest.raises(
- tvm.error.TVMError, match="Trying to update var 'x' with a different
const bound"
- ):
+ with pytest.raises(RuntimeError, match="Trying to update var 'x' with a
different const bound"):
ana.bind(x, tvm.ir.Range(0, 3))
diff --git a/tests/python/codegen/test_target_codegen.py
b/tests/python/codegen/test_target_codegen.py
index 391470f95a..6c4464cbd0 100644
--- a/tests/python/codegen/test_target_codegen.py
+++ b/tests/python/codegen/test_target_codegen.py
@@ -31,7 +31,7 @@ def test_buffer_store_predicate_not_supported(target):
B.vstore([T.Ramp(0, 2, 4)], T.Broadcast(1.0, 4),
predicate=T.Broadcast(T.bool(True), 4))
err_msg = "Predicated buffer store is not supported."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target(target):
tvm.compile(func)
@@ -51,7 +51,7 @@ def test_buffer_store_predicate_not_supported_gpu(target):
)
err_msg = "Predicated buffer store is not supported."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target(target):
tvm.compile(func)
@@ -69,7 +69,7 @@ def test_buffer_load_predicate_not_supported(target):
)
err_msg = "Predicated buffer load is not supported."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target(target):
tvm.compile(func)
@@ -89,7 +89,7 @@ def test_buffer_load_predicate_not_supported_gpu(target):
)
err_msg = "Predicated buffer load is not supported."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target(target):
tvm.compile(func)
diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py
b/tests/python/codegen/test_target_codegen_cross_llvm.py
index 11800f1e61..204784f031 100644
--- a/tests/python/codegen/test_target_codegen_cross_llvm.py
+++ b/tests/python/codegen/test_target_codegen_cross_llvm.py
@@ -79,7 +79,7 @@ def test_llvm_add_pipeline():
port = int(os.environ["TVM_RPC_ARM_PORT"])
try:
remote = rpc.connect(host, port)
- except tvm.error.TVMError as e:
+ except RuntimeError as e:
pass
if remote:
diff --git a/tests/python/codegen/test_target_codegen_cuda.py
b/tests/python/codegen/test_target_codegen_cuda.py
index 7ffa189b64..31a2d4dd1a 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -900,7 +900,7 @@ def test_invalid_reinterpret():
for tx in T.thread_binding(4, "threadIdx.x"):
B[tx] = T.call_intrin("uint8", "tirx.reinterpret", A[tx])
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
tvm.compile(func, target="cuda")
diff --git a/tests/python/codegen/test_target_codegen_llvm.py
b/tests/python/codegen/test_target_codegen_llvm.py
index 033d5af32f..8c2c8d6e07 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -956,7 +956,7 @@ def test_raise_exception_during_codegen():
for j in T.parallel(4):
B[i, j] = A[i, j] * 2.0
- with pytest.raises(tvm.TVMError) as e:
+ with pytest.raises(RuntimeError) as e:
tvm.compile(Module, target="llvm")
msg = str(e)
assert msg.find("Nested parallel loop is not supported") != -1
@@ -1146,7 +1146,7 @@ def test_call_packed_without_string_arg():
def main(A: T.Buffer(1, "float32")):
T.Call("int32", tvm.ir.Op.get("tirx.tvm_call_packed"), [A.data])
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
built = tvm.compile(Module, target="llvm")
@@ -1173,7 +1173,7 @@ def test_invalid_volatile_masked_buffer_load():
B[0:4] = A.vload([T.Ramp(0, 1, 4)],
predicate=T.Broadcast(T.bool(True), 4))
err_msg = "The masked load intrinsic does not support declaring load as
volatile."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target("llvm"):
tvm.compile(Module)
@@ -1191,7 +1191,7 @@ def test_invalid_volatile_masked_buffer_store():
)
err_msg = "The masked store intrinsic does not support declaring store as
volatile."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
with tvm.target.Target("llvm"):
tvm.compile(Module)
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
index c5e4db1c3a..5ffd60626a 100644
--- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -300,7 +300,7 @@ class TestElementWise:
is_hexagon = target_host.kind.name == "hexagon"
uses_2d_memory = "nchw-8h8w32c-2d" in [input_layout,
working_layout, output_layout]
if uses_2d_memory and not is_hexagon:
- stack.enter_context(pytest.raises(tvm.TVMError))
+ stack.enter_context(pytest.raises(RuntimeError))
tvm.compile(*schedule_args, target=target_host)
diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py
b/tests/python/contrib/test_hexagon/test_vtcm.py
index 9ca5164e8a..7ac4327cc4 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm.py
@@ -65,7 +65,7 @@ def test_vtcm_limit(vtcm_capacity, limited):
def _raises_exception(f):
try:
f()
- except tvm.base.TVMError:
+ except RuntimeError:
return True
return False
diff --git a/tests/python/contrib/test_rpc_tracker.py
b/tests/python/contrib/test_rpc_tracker.py
index a5351b62a6..67bad43ef0 100644
--- a/tests/python/contrib/test_rpc_tracker.py
+++ b/tests/python/contrib/test_rpc_tracker.py
@@ -84,7 +84,7 @@ def check_server_drop():
f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11
- except tvm.error.TVMError:
+ except RuntimeError:
pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
diff --git a/tests/python/ir/test_roundtrip_runtime_module.py
b/tests/python/ir/test_roundtrip_runtime_module.py
index 33bd1c11d1..d43c1544df 100644
--- a/tests/python/ir/test_roundtrip_runtime_module.py
+++ b/tests/python/ir/test_roundtrip_runtime_module.py
@@ -22,7 +22,6 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError
def test_csource_module():
diff --git a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
index 1bac08412a..4fdbdd82a1 100644
--- a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
+++ b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py
@@ -14,15 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: F401
import pytest
import tvm_ffi
import tvm
import tvm.testing
-from tvm import TVMError, tirx
from tvm import relax as rx
+from tvm import tirx
from tvm.ir import Range
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py
b/tests/python/relax/test_analysis_struct_info_analysis.py
index e2141bf94d..4d25f9f504 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -23,7 +23,7 @@ import tvm_ffi
import tvm
import tvm.testing
-from tvm import TVMError, ir, tirx
+from tvm import ir, tirx
from tvm import relax as rx
from tvm.script import relax as R
from tvm.script import tirx as T
@@ -405,7 +405,7 @@ def test_derive_call_ret_struct_info():
)
# Error: wrong number of arguments
- with pytest.raises(TVMError):
+ with pytest.raises(ValueError):
_check_derive(
bb,
func0(2),
@@ -414,7 +414,7 @@ def test_derive_call_ret_struct_info():
)
# Error:type mismatch
- with pytest.raises(TVMError):
+ with pytest.raises(ValueError):
_check_derive(bb, func0(2), [obj0], obj0)
# Tensor with vdevice
@@ -484,7 +484,7 @@ def test_derive_call_ret_struct_info():
)
# tuple length mismatch is not causes an error
- with pytest.raises(TVMError):
+ with pytest.raises(ValueError):
_check_derive(
bb,
func_tuple0(4),
diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py
b/tests/python/relax/test_analysis_suggest_layout_transforms.py
index e6b8f6edf1..fb382acc91 100644
--- a/tests/python/relax/test_analysis_suggest_layout_transforms.py
+++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py
@@ -80,7 +80,7 @@ def test_mismatch_transformations_and_num_params():
T.writes(relu[v_i0, v_i1, v_i2, v_i3])
relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2,
v_i3], T.float32(0))
- with pytest.raises(tvm.TVMError, match="Incompatible PrimFunc and
write_transformations"):
+ with pytest.raises(RuntimeError, match="Incompatible PrimFunc and
write_transformations"):
_ = relax.analysis.suggest_layout_transforms(
func=elemwise,
write_buffer_transforms=[
@@ -212,7 +212,7 @@ def test_invalid_index_map():
T.writes(relu[v_i0, v_i1, v_i2, v_i3])
relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2,
v_i3], T.float32(0))
- with pytest.raises(tvm.TVMError, match="Mismatch between output buffer
shape and index map"):
+ with pytest.raises(RuntimeError, match="Mismatch between output buffer
shape and index map"):
_ = relax.analysis.suggest_layout_transforms(
func=elemwise, write_buffer_transforms=[lambda n, h, w: (n, w, h)]
)
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index f2c6657b70..06af93869c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -162,7 +162,7 @@ def test_symbolic_var_across_functions():
def test_symbolic_var_invalid_type():
with pytest.raises(
- tvm.TVMError, match="the value in ShapeStructInfo can only have dtype
of int64"
+ RuntimeError, match="the value in ShapeStructInfo can only have dtype
of int64"
):
dim = tirx.Var("dim", "float32")
y = rx.Var("y", R.Tensor([dim], "float32"))
diff --git a/tests/python/relax/test_bind_params.py
b/tests/python/relax/test_bind_params.py
index dc13bc0992..9959f94b05 100644
--- a/tests/python/relax/test_bind_params.py
+++ b/tests/python/relax/test_bind_params.py
@@ -143,7 +143,7 @@ def test_error_on_unknown_var():
unknown_var = relax.Var("unknown_var")
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
before.bind_params({unknown_var: np.arange(16).astype("float32")})
@@ -153,7 +153,7 @@ def test_error_on_unknown_var_name():
R.func_attr({"global_symbol": "main"})
return A
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
before.bind_params({"unknown_var_name":
np.arange(16).astype("float32")})
diff --git a/tests/python/relax/test_bind_symbolic_vars.py
b/tests/python/relax/test_bind_symbolic_vars.py
index b822b589ee..9f36e46a38 100644
--- a/tests/python/relax/test_bind_symbolic_vars.py
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -68,7 +68,7 @@ def test_error_with_duplicate_var_names():
out: R.Tensor((N1, N2)) = R.matmul(A, B)
return out
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func.bind_symbolic_vars({"N": 64})
@@ -106,7 +106,7 @@ def test_error_with_nonexisting_var_name():
def func(A: R.Tensor(("M", "N"))):
return A
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func.bind_symbolic_vars({"non_existing_symbolic_var": 64})
@@ -117,7 +117,7 @@ def test_error_with_nonexisting_tir_var():
def func(A: R.Tensor(["M", "N"])):
return A
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func.bind_symbolic_vars({tvm.tirx.Var("M", "int64"): 64})
@@ -131,7 +131,7 @@ def test_error_with_multiple_definitions():
tir_var = func.params[0].struct_info.shape[0]
symbolic_var_map = {tir_var: 0, "M": 0}
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func.bind_symbolic_vars(symbolic_var_map)
@@ -144,7 +144,7 @@ def test_error_if_output_has_undefined():
outside_var = tvm.tirx.Var("outside_var", "int64")
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func.bind_symbolic_vars({"M": outside_var * 2})
diff --git a/tests/python/relax/test_binding_rewrite.py
b/tests/python/relax/test_binding_rewrite.py
index b7faad31da..e33335ca64 100644
--- a/tests/python/relax/test_binding_rewrite.py
+++ b/tests/python/relax/test_binding_rewrite.py
@@ -20,7 +20,6 @@ import pytest
import tvm
import tvm.testing
-from tvm.base import TVMError
from tvm.relax.analysis import name_to_binding
from tvm.relax.binding_rewrite import DataflowBlockRewrite
from tvm.relax.expr import DataflowVar, Var
@@ -141,7 +140,7 @@ def test_remove_unused_undef():
root_fn = Identity["main"]
dfb = root_fn.body.blocks[0]
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
rwt = DataflowBlockRewrite(dfb, root_fn)
rwt.remove_unused(Var("whatever"))
diff --git a/tests/python/relax/test_blockbuilder_core.py
b/tests/python/relax/test_blockbuilder_core.py
index bef5a58c07..c89597bbef 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -918,7 +918,7 @@ def test_error_when_unwrapping_dataflowvar():
rhs = bb.emit(func.bind_params({lhs: local_lhs}).body, "f")
out = bb.emit_output(rhs, "f")
- with pytest.raises(tvm.TVMError, match="Malformed AST"):
+ with pytest.raises(RuntimeError, match="Malformed AST"):
bb.emit_func_output(out)
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index 30a59ae30d..31bd012b44 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -241,7 +241,7 @@ def test_shape_expr():
m = tirx.Var("m", "int32")
with pytest.raises(
- tvm.TVMError, match="the value in ShapeStructInfo can only have dtype
of int64"
+ RuntimeError, match="the value in ShapeStructInfo can only have dtype
of int64"
):
rx.ShapeExpr([m, 3])
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py
b/tests/python/relax/test_frontend_nn_extern_module.py
index dba87c3fde..b837eb6866 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -121,11 +121,6 @@ def _check_ir_equality(mod):
def _compile_cc(src: Path, dst: Path):
- # pylint: disable=import-outside-toplevel
- from tvm.base import py_str
-
- # pylint: enable=import-outside-toplevel
-
cmd = ["g++", str(src)]
default_include_paths = [
tvm.libinfo.find_include_path(),
@@ -145,7 +140,7 @@ def _compile_cc(src: Path, dst: Path):
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
- msg += py_str(out)
+ msg += out.decode("utf-8", errors="replace")
msg += "\nCommand line: " + " ".join(cmd)
raise RuntimeError(msg)
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index 8a2eac04d1..7392435476 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: E501, F401, F841
+# ruff: noqa: E501, F841
import sys
import tempfile
@@ -26,7 +26,6 @@ import tvm_ffi
import tvm
import tvm.testing
from tvm import relax
-from tvm.base import TVMError
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py
b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 5cead461b2..bc54ce1fd1 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -161,7 +161,7 @@ def test_rnn_state_popn(rnn_state): # pylint:
disable=redefined-outer-name
verify_state(state, [0], [[np_two, np_three]])
f_popn(state, 0, 1)
verify_state(state, [0], [[np_zero, np_one]])
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
f_popn(state, 0, 1) # no available history to pop
diff --git a/tests/python/relax/test_struct_info.py
b/tests/python/relax/test_struct_info.py
index 622f1e369b..dd2e415f9a 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -20,8 +20,8 @@ import tvm_ffi
import tvm
import tvm.testing
-from tvm import TVMError, tirx
from tvm import relax as rx
+from tvm import tirx
def _check_equal(x, y, map_free_vars=False):
@@ -87,7 +87,7 @@ def test_prim_struct_info():
assert s2.dtype == "int32"
# wrong API constructors
- with pytest.raises((TVMError, TypeError)):
+ with pytest.raises((RuntimeError, TypeError)):
rx.PrimStructInfo([1])
@@ -136,7 +136,7 @@ def test_shape_struct_info():
str(s0)
# wrong argument type
- with pytest.raises((TVMError, TypeError)):
+ with pytest.raises((RuntimeError, TypeError)):
rx.ShapeStructInfo(1)
# cannot pass both ndim and values
diff --git a/tests/python/relax/test_training_append_loss.py
b/tests/python/relax/test_training_append_loss.py
index eff93570a6..abf5e3ac86 100644
--- a/tests/python/relax/test_training_append_loss.py
+++ b/tests/python/relax/test_training_append_loss.py
@@ -18,7 +18,6 @@
import pytest
import tvm.testing
-from tvm import TVMError
from tvm.ir.base import assert_structural_equal
from tvm.relax.training import AppendLoss
from tvm.script import ir as I
@@ -181,7 +180,7 @@ def test_error_return_value_vs_parameter():
return gv0
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss1, 2)(Module1)
# The numbers of backbone return value and loss parameter are not enough
@@ -203,7 +202,7 @@ def test_error_return_value_vs_parameter():
return gv0
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss2, 2)(Module2)
# Backbone returns nested tuple
@@ -227,7 +226,7 @@ def test_error_return_value_vs_parameter():
return gv0
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss3, 1)(Module3)
@@ -252,7 +251,7 @@ def test_error_more_blocks():
return gv
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss1)(Module1)
# loss more than one blocks
@@ -275,7 +274,7 @@ def test_error_more_blocks():
return gv1
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss2)(Module2)
@@ -299,7 +298,7 @@ def test_loss_return_value():
return gv0
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss)(Module)
# loss returns tuple
@@ -322,7 +321,7 @@ def test_loss_return_value():
return gv0, gv1
# fmt: on
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
AppendLoss("main", loss)(Module)
diff --git a/tests/python/relax/test_training_setup_trainer.py
b/tests/python/relax/test_training_setup_trainer.py
index a9343ca417..441dd7011d 100644
--- a/tests/python/relax/test_training_setup_trainer.py
+++ b/tests/python/relax/test_training_setup_trainer.py
@@ -19,7 +19,7 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError, relax
+from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.relax.training import SetupTrainer
from tvm.relax.training.loss import MSELoss
@@ -203,7 +203,7 @@ def test_invalid_mod():
[pred_sinfo, pred_sinfo],
)
- with pytest.raises((TVMError, ValueError)):
+ with pytest.raises((RuntimeError, ValueError)):
SetupTrainer(
MSELoss(reduction="sum"),
SGD(0.001),
diff --git a/tests/python/relax/test_training_trainer_numeric.py
b/tests/python/relax/test_training_trainer_numeric.py
index b96c46dd8f..0026d72af7 100644
--- a/tests/python/relax/test_training_trainer_numeric.py
+++ b/tests/python/relax/test_training_trainer_numeric.py
@@ -19,7 +19,7 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError, relax
+from tvm import relax
from tvm.relax.training import SetupTrainer, Trainer
from tvm.relax.training.loss import MSELoss
from tvm.relax.training.optimizer import SGD, Adam
@@ -160,9 +160,9 @@ def test_setting_error(target, dev):
dataset = _make_dataset()
# parameters are not inited
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
trainer.predict(dataset[0][0])
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
trainer.update(dataset[0][0], dataset[0][1])
diff --git a/tests/python/relax/test_transform_bind_symbolic_vars.py
b/tests/python/relax/test_transform_bind_symbolic_vars.py
index d7deae250f..6b0f3a075a 100644
--- a/tests/python/relax/test_transform_bind_symbolic_vars.py
+++ b/tests/python/relax/test_transform_bind_symbolic_vars.py
@@ -262,7 +262,7 @@ def test_error_for_unused_replacement():
def main(x: R.Tensor(("m", "n"), dtype="float32")):
return x
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.BindSymbolicVars({"non_existing_var_name": 16})(Before)
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index bf60261a7b..1362f93758 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -500,7 +500,7 @@ def test_cyclic_dependency():
relu_pat = is_op("relax.nn.relu")(conv_pat)
add_pat = is_op("relax.add")(relu_pat, wildcard())
- with pytest.raises(tvm.error.TVMError) as err:
+ with pytest.raises(RuntimeError) as err:
relax.transform.FuseOpsByPattern(
[("compiler_A.conv2d_relu_add", add_pat)], bind_constants=True
)(Branch)
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index 536e124ffa..57b91b9448 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -2443,7 +2443,7 @@ def
test_fuse_with_axis_separators_inconsistent_buffer_mapping():
return gv
with pytest.raises(
- tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same
relax var:.*"
+ RuntimeError, match=r"Inconsistent buffers.*and.*mapped to the same
relax var:.*"
):
relax.transform.FuseTIR()(Before)
diff --git a/tests/python/relax/test_transform_gradient.py
b/tests/python/relax/test_transform_gradient.py
index 75278e474e..962cf82104 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -21,7 +21,6 @@ import pytest
import tvm
import tvm.testing
from tvm import relax
-from tvm.base import TVMError
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
@@ -366,7 +365,7 @@ def test_intermediate_var_require_grads():
# z does not occur in function
z = relax.Var("z", R.Tensor((3, 3), "float32"))
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main", [x, lv1, z])(Before)
@@ -1109,7 +1108,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(TargetNotTensor)
@I.ir_module(s_tir=True)
@@ -1121,7 +1120,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(TargetNotScalar)
@I.ir_module(s_tir=True)
@@ -1133,7 +1132,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(TargetNotFloat)
@I.ir_module(s_tir=True)
@@ -1145,7 +1144,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main",
target_index=1)(ReturnScalarAndWrongTargetIndex)
@I.ir_module(s_tir=True)
@@ -1158,7 +1157,7 @@ def test_report_error():
R.output(gv1, gv2)
return gv1, gv2
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main",
target_index=2)(ReturnTupleAndWrongTargetIndex)
@I.ir_module(s_tir=True)
@@ -1170,7 +1169,7 @@ def test_report_error():
R.output(gv)
return gv, (gv, gv)
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main", target_index=1)(IndexedTargetNotVar)
@I.ir_module(s_tir=True)
@@ -1180,7 +1179,7 @@ def test_report_error():
gv = R.sum(x0)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(NoDataflow)
@I.ir_module(s_tir=True)
@@ -1195,7 +1194,7 @@ def test_report_error():
gv1 = R.sum(x0)
return gv1
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(MultiBlocks)
@I.ir_module(s_tir=True)
@@ -1226,10 +1225,10 @@ def test_report_error():
with pytest.raises(ValueError):
relax.transform.Gradient("main1")(NormalModule)
# wrong function type
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("sum")(NormalModule)
# no such var
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main",
require_grads=MultiBlocks["main"].params[0])(NormalModule)
@I.ir_module(s_tir=True)
@@ -1242,7 +1241,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(IntDtype)
@I.ir_module(s_tir=True)
@@ -1257,7 +1256,7 @@ def test_report_error():
R.output(gv)
return gv
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.Gradient("main")(IntDtypeTuple)
diff --git a/tests/python/relax/test_transform_lift_transform_params.py
b/tests/python/relax/test_transform_lift_transform_params.py
index 199cf9ab27..56299fc4fe 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -604,7 +604,7 @@ def
test_incompatible_weights_in_shared_transform_raises_error():
R.output(output)
return output
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)
@@ -649,7 +649,7 @@ def
test_incompatible_shape_in_shared_transform_raises_error():
R.output(output)
return output
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)
@@ -694,7 +694,7 @@ def
test_incompatible_dtype_in_shared_transform_raises_error():
R.output(output)
return output
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 61a17f3991..7a658c0a35 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -20,7 +20,7 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError, relax
+from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
@@ -1242,7 +1242,7 @@ def test_invalid_tir_var_upper_bound():
R.func_attr({"tir_var_upper_bound": {"n": [4]},
"relax.force_pure": True})
return x
- with pytest.raises((TVMError, TypeError)):
+ with pytest.raises((RuntimeError, TypeError)):
relax.transform.StaticPlanBlockMemory()(Module)
@@ -1254,7 +1254,7 @@ def test_invalid_tir_var_lower_bound():
R.func_attr({"tir_var_lower_bound": {"n": [4]},
"relax.force_pure": True})
return x
- with pytest.raises((TVMError, TypeError)):
+ with pytest.raises((RuntimeError, TypeError)):
relax.transform.StaticPlanBlockMemory()(Module)
diff --git a/tests/python/relax/test_vm_build.py
b/tests/python/relax/test_vm_build.py
index e591eb6fd3..2ef1f780ef 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -604,7 +604,7 @@ def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
with pytest.raises(RuntimeError):
func(2, Shape([4, 12]), 1)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
func(Shape([2]))
diff --git a/tests/python/relax/test_vm_builtin_lower.py
b/tests/python/relax/test_vm_builtin_lower.py
index 50bbc4fb72..59ac5c3f12 100644
--- a/tests/python/relax/test_vm_builtin_lower.py
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -79,7 +79,7 @@ def test_vm_builtin_alloc_tensor_raises_error():
gv0 = alloc
return gv0
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
relax.transform.LowerRuntimeBuiltin()(Before)
diff --git a/tests/python/relax/test_vm_cuda_graph.py
b/tests/python/relax/test_vm_cuda_graph.py
index a7390cc9a2..d558e6f51b 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -174,7 +174,7 @@ def test_capture_error_is_recoverable():
arg = tvm.runtime.tensor(np.arange(16).astype("float16"), dev)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
vm["main"](arg)
diff --git a/tests/python/relax/test_vm_execbuilder.py
b/tests/python/relax/test_vm_execbuilder.py
index 9d8d19d747..35ac921956 100644
--- a/tests/python/relax/test_vm_execbuilder.py
+++ b/tests/python/relax/test_vm_execbuilder.py
@@ -21,7 +21,7 @@ import pytest
import tvm_ffi
import tvm
-from tvm import TVMError, relax
+from tvm import relax
from tvm.relax.testing.vm import check_saved_func
from tvm.script import relax as R
@@ -76,7 +76,7 @@ def test_vm_multiple_func():
def test_vm_checker():
ib = relax.ExecBuilder()
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
with ib.function("func0", num_inputs=2):
ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2))
ib.emit_ret(ib.r(2))
diff --git a/tests/python/runtime/test_runtime_rpc.py
b/tests/python/runtime/test_runtime_rpc.py
index 2db325b135..079116281f 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -106,7 +106,7 @@ def test_rpc_simple():
assert f1(10) == 11
f3 = client.get_function("rpc.test.except")
- with pytest.raises(tvm.base.TVMError):
+ with pytest.raises(RuntimeError):
f3("abc")
f2 = client.get_function("rpc.test.strcat")
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
index 1dee52eb57..9ec11c077b 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
@@ -27,7 +27,6 @@ from tvm_ffi import register_global_func
import tvm
import tvm.testing
from tvm import te
-from tvm.error import TVMError
from tvm.ir.module import IRModule
from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule import TuneContext
@@ -314,7 +313,7 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
)
post_order_apply = context.space_generator
with pytest.raises(
- TVMError,
+ RuntimeError,
match=r".*Duplicated block name matmul in function main not
supported!",
):
post_order_apply.generate_design_space(mod)
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
index a1b68134e7..5dd2d46e76 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
@@ -64,7 +64,7 @@ def _apply_rewrite_layout(mod):
sch = tvm.s_tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
if not ctx.space_generator.postprocs[0].apply(sch):
- raise tvm.TVMError("RewriteLayout postproc failed")
+ raise RuntimeError("RewriteLayout postproc failed")
return sch.mod
@@ -130,7 +130,7 @@ def test_rewritten_buffers_must_occur_within_block():
T.evaluate(A[i, j])
mod = tvm.IRModule.from_expr(before)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
_apply_rewrite_layout(mod)
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
index 5515d66f9a..08c84b36d6 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
@@ -23,7 +23,6 @@ import pytest
import tvm
import tvm.testing
-from tvm.base import TVMError
from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule.space_generator import (
PySpaceGenerator,
@@ -101,7 +100,7 @@ def test_meta_schedule_design_space_generator_NIE():
self.mutator_probs = {}
with pytest.raises(
- TVMError, match="PySpaceGenerator's InitializeWithTuneContext method
not implemented!"
+ RuntimeError, match="PySpaceGenerator's InitializeWithTuneContext
method not implemented!"
):
generator = TestPySpaceGenerator()
generator._initialize_with_tune_context(TuneContext())
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py
b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py
index b8af9c975c..306f4ed12d 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py
@@ -1526,7 +1526,7 @@ def test_reduction_rfactor_predicate(): # pylint:
disable=invalid-name
B = s.get_sblock("B")
_, ko, _ = s.get_loops(B)
# TODO: should be a tvm.s_tir.ScheduleError
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
rf_block = s.rfactor(ko, 1)
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py
b/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py
index 3b8e5d4611..54ae20ac74 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py
@@ -1209,7 +1209,7 @@ def test_index_map_dtype_legalize():
sch = tvm.s_tir.Schedule(func)
# # The following error is raised from the IterVar constructor without the
dtype legalization.
- # # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs.
int32) :
+ # # RuntimeError: Check failed: dom->extent.dtype() == var.dtype() (int64
vs. int32) :
# # The dtype of the extent of an IterVar (int64) must match its
associated Var's dtype (int32)
sch.transform_layout(
sch.get_sblock("block"), buffer="A", index_map=lambda h: [h // 8, h %
8], pad_value=0
@@ -1247,7 +1247,7 @@ def test_index_map_dtype_legalize_with_constant():
# Prior to the bugfix, this resulted in the following error is
# raised from the IterVar constructor.
#
- # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs.
int32) :
+ # RuntimeError: Check failed: dom->extent.dtype() == var.dtype() (int64
vs. int32) :
# The dtype of the extent of an IterVar (int64) must match its associated
Var's dtype (int32)
sch.transform_layout(block="block", buffer="A", index_map=func,
pad_value=0)
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py
b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py
index 60c628b1e2..ff93ad3319 100644
---
a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py
+++
b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py
@@ -89,7 +89,7 @@ def test_error_if_predicate_uses_block_variables():
T.where(vi < 6)
T.evaluate(0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
tvm.s_tir.transform.ConvertBlocksToOpaque()(Before)
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py
b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py
index 338ab63b21..1ba1b4839d 100644
---
a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py
+++
b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py
@@ -23,7 +23,7 @@ import pytest
import tvm
import tvm.s_tir.tensor_intrin.cuda
import tvm.testing
-from tvm import TVMError, te, tirx
+from tvm import te, tirx
from tvm.s_tir.meta_schedule.testing import te_workload
from tvm.s_tir.tensor_intrin.cuda import (
LDMATRIX_f16_A_DYN_INTRIN,
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py
b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py
index 42f0c98d05..b385ce7249 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py
@@ -32,7 +32,7 @@ def _check(original, transformed):
def _check_fail(original):
mod = tvm.IRModule.from_expr(original)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
mod = tvm.s_tir.transform.LowerMatchBuffer()(mod)
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
index a1dc101c74..bd59e5e395 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py
@@ -19,7 +19,6 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError
from tvm.script import ir as I
from tvm.script import tirx as T
@@ -107,7 +106,7 @@ def test_raise_error_for_undef_as_store_indices():
val: T.let[T.int32] = T.undef(dtype="int32")
A[val] = 5
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
tvm.s_tir.transform.RemoveStoreUndef()(Before)
@@ -124,7 +123,7 @@ def test_raise_error_for_undef_as_load_indices():
def main(A: T.Buffer(1, "int32"), B: T.Buffer(1, "int32")):
B[0] = A[T.undef(dtype="int32")]
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
tvm.s_tir.transform.RemoveStoreUndef()(Before)
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index c037fcadd2..fc6cf209d3 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -443,7 +443,7 @@ def test_webgpu_target_subgroup_attrs():
{"kind": "webgpu", "thread_warp_size": 32},
{"kind": "webgpu", "thread_warp_size": 32, "supports_subgroups":
False},
]:
- with pytest.raises(tvm.TVMError, match="requires
supports_subgroups=true"):
+ with pytest.raises(ValueError, match="requires
supports_subgroups=true"):
Target(config)
diff --git a/tests/python/te/test_te_verify_compute.py
b/tests/python/te/test_te_verify_compute.py
index cef420cf7f..a8886de8a0 100644
--- a/tests/python/te/test_te_verify_compute.py
+++ b/tests/python/te/test_te_verify_compute.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E731, F841
-import tvm
from tvm import te
@@ -36,14 +35,14 @@ def test_verify_compute():
# Valid compute
try:
B = te.compute((n,), f1, name="B")
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
assert False
#
# Valid compute
try:
B = te.compute((n,), f2, name="B")
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
assert False
#
@@ -51,7 +50,7 @@ def test_verify_compute():
try:
B = te.compute((n,), f3, name="B")
assert False
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
pass
#
@@ -59,7 +58,7 @@ def test_verify_compute():
try:
B = te.compute((n,), f4, name="B")
assert False
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
pass
#
@@ -67,7 +66,7 @@ def test_verify_compute():
try:
B0, B1 = te.compute((n,), f5, name="B")
assert False
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
pass
#
@@ -75,7 +74,7 @@ def test_verify_compute():
try:
B0, B1 = te.compute((n,), f6, name="B")
assert False
- except tvm.base.TVMError as ex:
+ except RuntimeError as ex:
pass
diff --git a/tests/python/tirx-base/test_tir_base.py
b/tests/python/tirx-base/test_tir_base.py
index 5011147998..2abe612890 100644
--- a/tests/python/tirx-base/test_tir_base.py
+++ b/tests/python/tirx-base/test_tir_base.py
@@ -22,7 +22,6 @@ import pytest
import tvm
from tvm import tirx
-from tvm.base import TVMError
from tvm.ir.transform import PassContext
from tvm.script import tirx as T
@@ -70,7 +69,7 @@ def test_fail_implicit_downcasts_same_type():
bits = [8, 16, 32, 64]
for type in ["float", "int", "uint"]:
for i in range(len(bits) - 1):
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
assignment_helper(
store_dtype=f"{type}{bits[i]}",
value_dtype=f"{type}{bits[i + 1]}"
)
@@ -89,7 +88,7 @@ def test_cast_between_types():
assignment_helper(store_dtype, value_dtype)
else:
# TODO: we might want to allow casts between uint and int types
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
assignment_helper(store_dtype, value_dtype)
diff --git a/tests/python/tirx-base/test_tir_constructor.py
b/tests/python/tirx-base/test_tir_constructor.py
index b2628e1b00..3ad59119cd 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -249,7 +249,8 @@ def test_stmt_constructor():
def test_float_constructor_requires_float_dtype():
- with pytest.raises(tvm.TVMError):
+ # FloatImm dtype validation raises a builtin ValueError.
+ with pytest.raises(ValueError):
tvm.tirx.FloatImm("int32", 1.0)
diff --git a/tests/python/tirx-base/test_tir_imm_values.py
b/tests/python/tirx-base/test_tir_imm_values.py
index 2e873896a1..bf0002bea4 100644
--- a/tests/python/tirx-base/test_tir_imm_values.py
+++ b/tests/python/tirx-base/test_tir_imm_values.py
@@ -56,7 +56,11 @@ def test_tir_make_intimm(dtype, literals):
)
def test_tir_invalid_intimm(dtype, literals):
for l in literals:
- with pytest.raises(tvm.TVMError):
+ # Out-of-range positive literals raise a builtin ValueError from
+ # the IntImm range check; negative-into-unsigned raises an
+ # InternalError ("cannot make uint from negative value") which is a
+ # RuntimeError subclass. Accept either.
+ with pytest.raises((RuntimeError, ValueError)):
tirx.const(l, dtype)
@@ -130,7 +134,8 @@ def test_tir_make_floatimm(dtype, literals):
def test_tir_invalid_floatimm(dtype, literals):
"""Currently only fp16 and fp32 have range check."""
for l in literals:
- with pytest.raises(tvm.TVMError):
+ # FloatImm out-of-range raises a builtin ValueError.
+ with pytest.raises(ValueError):
tirx.const(l, dtype)
@@ -292,7 +297,7 @@ def test_tir_floatimm_const_fold():
check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 3.0e30, 3.0e-30,
np.inf)
# divide by zero
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 1.0, 0.0)
# nan and inf
@@ -344,9 +349,9 @@ def test_tir_int8_const_fold():
check_tir_const_fold("int8", lambda x, y: x * y, fmul, 127, 127, 1)
# divide by zero
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int8", lambda x, y: tirx.floordiv(x, y),
ffloordiv, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int8", lambda x, y: tirx.truncdiv(x, y),
ftruncdiv, 1, 0)
# i8 mod folding is not implemented
@@ -399,13 +404,13 @@ def test_tir_uint8_const_fold():
check_tir_const_fold("uint8", lambda x, y: x + y, fadd, 255, 1, 0)
# zero sub
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("uint8", lambda x, y: x - y, fsub, 0, 10)
# divide by zero
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("uint8", lambda x, y: tirx.floordiv(x, y),
ffloordiv, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("uint8", lambda x, y: tirx.truncdiv(x, y),
ftruncdiv, 1, 0)
# u8 mod folding is not implemented
@@ -473,13 +478,13 @@ def test_tir_int32_const_fold():
assert -(2**31) <= int(tirx.const(-(2**31), "int32") - tirx.const(1,
"int32")) < 2**31
# divide by zero
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int32", lambda x, y: tirx.floordiv(x, y),
ffloordiv, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int32", lambda x, y: tirx.floormod(x, y),
ffloormod, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int32", lambda x, y: tirx.truncdiv(x, y),
ftruncdiv, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("int32", lambda x, y: tirx.truncmod(x, y),
ftruncmod, 1, 0)
# randomized check
@@ -550,9 +555,9 @@ def test_tir_uint32_const_fold():
assert 0 <= int(tirx.const(2**32 - 1, "uint32") + tirx.const(1, "uint32"))
< 2**32
# divide by zero
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("uint32", lambda x, y: tirx.floordiv(x, y),
ffloordiv, 1, 0)
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
check_tir_const_fold("uint32", lambda x, y: tirx.truncdiv(x, y),
ftruncdiv, 1, 0)
# u8 mod folding is not implemented
diff --git a/tests/python/tirx-base/test_tir_index_map.py
b/tests/python/tirx-base/test_tir_index_map.py
index 539ff34804..2e2d6fa223 100644
--- a/tests/python/tirx-base/test_tir_index_map.py
+++ b/tests/python/tirx-base/test_tir_index_map.py
@@ -117,7 +117,7 @@ def test_inverse_accepts_external_analyzer():
def test_nonbijective_inverse_gives_error():
index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
index_map.inverse([14])
diff --git a/tests/python/tirx-base/test_tir_nodes.py
b/tests/python/tirx-base/test_tir_nodes.py
index 4a5ee3ae35..5ed20f99c6 100644
--- a/tests/python/tirx-base/test_tir_nodes.py
+++ b/tests/python/tirx-base/test_tir_nodes.py
@@ -103,7 +103,7 @@ def test_cast():
assert z.lanes == 4
s = tvm.tirx.StringImm("s")
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
try:
s.astype("int")
except Exception as e:
@@ -228,7 +228,7 @@ def test_float_bitwise():
try:
test(t, 10.0)
assert False
- except tvm.TVMError:
+ except RuntimeError:
pass
try:
~t
@@ -245,7 +245,7 @@ def test_shift_bounds():
try:
test(*testcase)
assert False
- except tvm.TVMError:
+ except RuntimeError:
pass
# positive case
@@ -264,7 +264,7 @@ def test_divide_by_zero():
try:
test(tvm.tirx.const(5, "int32"), tvm.tirx.const(0, "int32"))
assert False
- except tvm.TVMError:
+ except RuntimeError:
pass
@@ -395,7 +395,7 @@ def test_scalable_vec(lanes, node_func):
)
@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast])
def test_scalable_vec_error(lanes, node_func):
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
node_func(lanes)
@@ -435,7 +435,7 @@ def test_buffer_store_predicate_invalid_scalability():
predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 4)
err_msg = "Predicate mask dtype and value dtype must both be scalable."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferStore(b, value, [index], predicate)
@@ -449,7 +449,7 @@ def test_buffer_store_predicate_invalid_lanes():
"Got a predicate mask with 8 lanes, but trying to store a "
"value with 4 lanes. The number of lanes must match."
)
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferStore(b, value, [index], predicate)
@@ -460,7 +460,7 @@ def test_buffer_store_predicate_elements_invalid_type():
predicate = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale())
err_msg = "Predicate mask elements must be boolean values, but got int32."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferStore(b, value, [index], predicate)
@@ -470,7 +470,7 @@ def test_buffer_load_predicate_elements_invalid_type():
predicate = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale())
err_msg = "Predicate mask elements must be boolean values, but got int32."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferLoad(b, [index], predicate)
@@ -480,7 +480,7 @@ def test_buffer_store_predicate_invalid_scalability():
predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 4)
err_msg = "Predicate mask dtype and load indices must both be scalable."
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferLoad(b, [index], predicate)
@@ -493,7 +493,7 @@ def test_buffer_store_predicate_invalid_lanes():
"Got a predicate mask with 8 lanes, but trying to load a "
"vector with 4 lanes. The number of lanes must match."
)
- with pytest.raises(tvm.TVMError, match=err_msg):
+ with pytest.raises(RuntimeError, match=err_msg):
tvm.tirx.BufferLoad(b, [index], predicate)
diff --git a/tests/python/tirx-base/test_tir_ops.py
b/tests/python/tirx-base/test_tir_ops.py
index 34b62e0fd1..076a4a8afc 100644
--- a/tests/python/tirx-base/test_tir_ops.py
+++ b/tests/python/tirx-base/test_tir_ops.py
@@ -24,7 +24,7 @@ import tvm.testing
def check_throws(f):
try:
f()
- except tvm.error.TVMError:
+ except RuntimeError:
pass
else:
raise AssertionError("Should have raised an exception but didn't.")
diff --git a/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py
b/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py
index 12b4833628..222ea7260e 100644
--- a/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py
+++ b/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py
@@ -92,14 +92,14 @@ def test_hide_buffer_access_write():
def test_hide_buffer_access_fail_buffer_type():
sch = tvm.s_tir.Schedule(indirect_mem_access, debug_mask="all")
block_b = sch.get_sblock("B")
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
sch.unsafe_hide_buffer_access(block_b, "opaque", [0])
def test_hide_buffer_access_fail_buffer_index():
sch = tvm.s_tir.Schedule(indirect_mem_access, debug_mask="all")
block_b = sch.get_sblock("B")
- with pytest.raises(tvm.error.TVMError):
+ with pytest.raises(RuntimeError):
sch.unsafe_hide_buffer_access(block_b, "read", [2])
diff --git
a/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py
b/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py
index 6668100719..8bf1965a64 100644
---
a/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py
+++
b/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py
@@ -19,7 +19,6 @@ import pytest
import tvm
import tvm.testing
-from tvm import TVMError
from tvm.script import tirx as T
@@ -216,7 +215,7 @@ def test_fail_on_buffer_map():
B[vi] = A[vi] + T.int64(1)
mod = tvm.IRModule.from_expr(func)
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"]
@@ -236,7 +235,7 @@ def test_fail_on_buffer_map():
B[vi] = T.cast(C[vi] + T.int64(1), "int32")
mod = tvm.IRModule.from_expr(func)
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"]
diff --git
a/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py
b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py
index 8a4e49a755..bbd6df01ef 100644
--- a/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py
@@ -239,7 +239,7 @@ def test_lower_allocate_requires_device_id():
buf = T.decl_buffer(16, "float32", data=ptr.data)
buf[0] = 0.0
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
tvm.tirx.transform.LowerTVMBuiltin()(Before)
@@ -263,7 +263,7 @@ def test_lower_allocate_requires_device_type():
buf = T.decl_buffer(1024 * 1024, "float32", data=ptr.data)
buf[0] = 0.0
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
tvm.tirx.transform.LowerTVMBuiltin()(Before)
diff --git a/tests/python/tirx-transform/test_tir_transform_vectorize.py
b/tests/python/tirx-transform/test_tir_transform_vectorize.py
index 13c8534e80..301e05b653 100644
--- a/tests/python/tirx-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tirx-transform/test_tir_transform_vectorize.py
@@ -354,7 +354,7 @@ def test_vectorize_while_fail():
try:
tvm.compile(Module, target="llvm")
assert False
- except tvm.error.TVMError as e:
+ except RuntimeError as e:
error_msg = str(e).split("\n")[-1]
expected = "A while loop inside a vectorized loop not supported"
assert expected in error_msg
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py
b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 878fd39743..a26f74b4af 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -608,7 +608,7 @@ def test_block_annotation_merge():
assert _to_dict(func2.body.block.annotations) == {"key1": "block1"}
- with pytest.raises(tvm.TVMError):
+ with pytest.raises(RuntimeError):
@T.prim_func(s_tir=True)
def func3():
diff --git a/tests/python/tvmscript/test_tvmscript_printer_ir.py
b/tests/python/tvmscript/test_tvmscript_printer_ir.py
index f2044d63c0..027317497c 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_ir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_ir.py
@@ -18,7 +18,7 @@
import pytest
-from tvm import IRModule, TVMError
+from tvm import IRModule
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import ir as I
from tvm.script.ir_builder import tirx as T
@@ -59,7 +59,7 @@ def test_failed_invalid_prefix():
T.func_name("foo")
mod = ib.get()
- with pytest.raises(TVMError):
+ with pytest.raises(RuntimeError):
mod.script(ir_prefix="2I")