This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 92de8e9afe [Transform] Use callable() instead of isinstance() for type
checking (#14248)
92de8e9afe is described below
commit 92de8e9afeac74567a3157ce1f2d4375245bde3a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Mar 11 10:41:10 2023 -0600
[Transform] Use callable() instead of isinstance() for type checking
(#14248)
Previously, type-checking of callable arguments, such as those passed to
`tvm.ir.transform.module_pass`, was done using
`isinstance(arg, (types.FunctionType, types.LambdaType))`. This check
can give false negatives for valid Python objects, such as bound
methods or instances of a class that implements `__call__`.
This commit replaces these checks with the built-in function `callable()`,
which accepts any Python object that can be called using function-call
syntax.
---
python/tvm/ir/transform.py | 3 +--
python/tvm/relay/transform/transform.py | 2 +-
python/tvm/te/hybrid/parser.py | 3 +--
python/tvm/tir/transform/function_pass.py | 3 +--
4 files changed, 4 insertions(+), 7 deletions(-)
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 17995bfa78..f7d40dc681 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name,unused-argument
"""Common pass infrastructure across IR variants."""
-import types
import inspect
import functools
@@ -340,7 +339,7 @@ def module_pass(pass_func=None, opt_level=None, name=None,
required=None):
info = PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_module_pass(pass_arg, info)
- if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ if not callable(pass_arg):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_transform_api.MakeModulePass(pass_arg, info)
diff --git a/python/tvm/relay/transform/transform.py
b/python/tvm/relay/transform/transform.py
index 1f5b91da44..4c609620cb 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1049,7 +1049,7 @@ def function_pass(pass_func=None, opt_level=None,
name=None, required=None):
info = tvm.transform.PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
- if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ if not callable(pass_arg):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index ec103ac188..bd47e41630 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -20,7 +20,6 @@ import ast
import operator
import logging
import sys
-import types
import numbers
from enum import Enum
@@ -142,7 +141,7 @@ class HybridParser(ast.NodeVisitor):
self.symbols = {} # Symbol table
for k, v in symbols.items():
- if isinstance(v, types.FunctionType):
+ if callable(v):
self.add_symbol(k, Symbol.Callable, v)
self.closure_vars = closure_vars
diff --git a/python/tvm/tir/transform/function_pass.py
b/python/tvm/tir/transform/function_pass.py
index 9450ade34e..9fa0e3bc18 100644
--- a/python/tvm/tir/transform/function_pass.py
+++ b/python/tvm/tir/transform/function_pass.py
@@ -16,7 +16,6 @@
# under the License.
"""TIR specific function pass support."""
import inspect
-import types
import functools
from typing import Callable, List, Optional, Union
@@ -151,7 +150,7 @@ def prim_func_pass(
info = PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
- if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ if not callable(pass_arg):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.CreatePrimFuncPass(pass_arg, info) # type: ignore