This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 92de8e9afe [Transform] Use callable() instead of isinstance() for type checking (#14248)
92de8e9afe is described below

commit 92de8e9afeac74567a3157ce1f2d4375245bde3a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Mar 11 10:41:10 2023 -0600

    [Transform] Use callable() instead of isinstance() for type checking (#14248)
    
    Previously, type-checking of callable arguments, such as the one
    passed to `tvm.ir.transform.module_pass`, was done using
    `isinstance(arg, (types.FunctionType, types.LambdaType))`.  This check
    can give false negatives for valid Python callables, such as a bound
    method or an instance of a class that implements `__call__`.
    
    This commit replaces the checks with the builtin function `callable()`,
    which handles any Python object that can be called using function-like
    syntax.
---
 python/tvm/ir/transform.py                | 3 +--
 python/tvm/relay/transform/transform.py   | 2 +-
 python/tvm/te/hybrid/parser.py            | 3 +--
 python/tvm/tir/transform/function_pass.py | 3 +--
 4 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 17995bfa78..f7d40dc681 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name,unused-argument
 """Common pass infrastructure across IR variants."""
-import types
 import inspect
 import functools
 
@@ -340,7 +339,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
         info = PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_module_pass(pass_arg, info)
-        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+        if not callable(pass_arg):
             raise TypeError("pass_func must be a callable for Module pass")
         return _ffi_transform_api.MakeModulePass(pass_arg, info)
 
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 1f5b91da44..4c609620cb 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1049,7 +1049,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
         info = tvm.transform.PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_function_pass(pass_arg, info)
-        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+        if not callable(pass_arg):
             raise TypeError("pass_func must be a callable for Module pass")
         return _ffi_api.MakeFunctionPass(pass_arg, info)
 
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index ec103ac188..bd47e41630 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -20,7 +20,6 @@ import ast
 import operator
 import logging
 import sys
-import types
 import numbers
 
 from enum import Enum
@@ -142,7 +141,7 @@ class HybridParser(ast.NodeVisitor):
 
         self.symbols = {}  # Symbol table
         for k, v in symbols.items():
-            if isinstance(v, types.FunctionType):
+            if callable(v):
                 self.add_symbol(k, Symbol.Callable, v)
 
         self.closure_vars = closure_vars
diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py
index 9450ade34e..9fa0e3bc18 100644
--- a/python/tvm/tir/transform/function_pass.py
+++ b/python/tvm/tir/transform/function_pass.py
@@ -16,7 +16,6 @@
 # under the License.
 """TIR specific function pass support."""
 import inspect
-import types
 import functools
 from typing import Callable, List, Optional, Union
 
@@ -151,7 +150,7 @@ def prim_func_pass(
         info = PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_function_pass(pass_arg, info)
-        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+        if not callable(pass_arg):
             raise TypeError("pass_func must be a callable for Module pass")
         return _ffi_api.CreatePrimFuncPass(pass_arg, info)  # type: ignore
 

Reply via email to