Lunderberg commented on a change in pull request #8010:
URL: https://github.com/apache/tvm/pull/8010#discussion_r656737002



##########
File path: python/tvm/testing.py
##########
@@ -718,31 +844,421 @@ def parametrize_targets(*args):
 
     Example
     -------
-    >>> @tvm.testing.parametrize
+    >>> @tvm.testing.parametrize_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+    """
+
+    def wrap(targets):
+        def func(f):
+            return pytest.mark.parametrize(
+                "target", _pytest_target_params(targets), scope="session"
+            )(f)
+
+        return func
+
+    if len(args) == 1 and callable(args[0]):
+        return wrap(None)(args[0])
+    return wrap(args)
+
+
+def exclude_targets(*args):
+    """Exclude a test from running on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but want to exclude some particular target or targets.  For
+    example, a test may wish to be run against all targets in
+    tvm.testing.enabled_targets(), except for a particular target that
+    does not support the capabilities.
+
+    Applies pytest.mark.skipif to the targets given.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form
+        `def test_xxxxxxxxx(target, dev):`, where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to exclude.
+
+    Example
+    -------
+    >>> @tvm.testing.exclude_targets("cuda")
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
 
     Or
 
-    >>> @tvm.testing.parametrize("llvm", "cuda")
+    >>> @tvm.testing.exclude_targets("llvm", "cuda")
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
+
     """
 
-    def wrap(targets):
-        def func(f):
-            params = [
-                pytest.param(target, tvm.device(target, 0), 
marks=_target_to_requirement(target))
-                for target in targets
-            ]
-            return pytest.mark.parametrize("target,dev", params)(f)
+    def wraps(func):
+        func.tvm_excluded_targets = args
+        return func
+
+    return wraps
+
+
+def known_failing_targets(*args):
+    """Skip a test that is known to fail on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but know that it fails for some targets.  For example, a newly
+    implemented runtime may not support all features being tested, and
+    should be excluded.
+
+    Applies pytest.mark.xfail to the targets given.
 
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form
+        `def test_xxxxxxxxx(target, dev):`, where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to skip.
+
+    Example
+    -------
+    >>> @tvm.testing.known_failing_targets("cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    Or
+
+    >>> @tvm.testing.known_failing_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    """
+
+    def wraps(func):
+        func.tvm_known_failing_targets = args
         return func
 
-    if len(args) == 1 and callable(args[0]):
-        targets = [t for t, _ in enabled_targets()]
-        return wrap(targets)(args[0])
-    return wrap(args)
+    return wraps
+
+
+def parameter(*values, ids=None):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using ``tvm.testing.parameter`` will define a
+    parametrized pytest fixture that can be used by test
+    functions. This is intended for cases that have no setup cost,
+    such as strings, integers, tuples, etc.  For cases that have a
+    significant setup cost, please use :py:func:`tvm.testing.fixture`
+    instead.
+
+    If a test function accepts multiple parameters defined using
+    ``tvm.testing.parameter``, then the test will be run using every
+    combination of those parameters.
+
+    The parameter definition applies to all tests in a module.  If a
+    specific test should have different values for the parameter, that
+    test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    values
+       A list of parameter values.  A unit test that accepts this
+       parameter as an argument will be run once for each parameter
+       given.
+
+    ids : List[str], optional
+       A list of names for the parameters.  If None, pytest will
+       generate a name from the value.  These generated names may not
+       be readable/useful for composite types such as tuples.
+
+    Returns
+    -------
+    function
+       A function output from pytest.fixture.
+
+    Example
+    -------
+    >>> size = tvm.testing.parameter(1, 10, 100)
+    >>> def test_using_size(size):
+    >>>     ... # Test code here
+
+    Or
+
+    >>> shape = tvm.testing.parameter((5,10), (512,1024), ids=['small','large'])
+    >>> def test_using_size(shape):
+    >>>     ... # Test code here
+
+    """
+
+    # Optional cls parameter in case a parameter is defined inside a
+    # class scope.
+    @pytest.fixture(params=values, ids=ids)
+    def as_fixture(*cls, request):
+        return request.param
+
+    return as_fixture
+
+
+_parametrize_group = 0
+
+
+def parameters(*value_sets):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using tvm.testing.parameters will define a
+    parametrized pytest fixture that can be used by test
+    functions. Like :py:func:`tvm.testing.parameter`, this is intended
+    for cases that have no setup cost, such as strings, integers,
+    tuples, etc.  For cases that have a significant setup cost, please
+    use :py:func:`tvm.testing.fixture` instead.
+
+    Unlike :py:func:`tvm.testing.parameter`, if a test function
+    accepts multiple parameters defined using a single call to
+    ``tvm.testing.parameters``, then the test will only be run once
+    for each set of parameters, not for all combinations of
+    parameters.
+
+    These parameter definitions apply to all tests in a module.  If a
+    specific test should have different values for some parameters,
+    that test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    value_sets : List[tuple]
+       A list of parameter value sets.  Each set of values represents
+       a single combination of values to be tested.  A unit test that
+       accepts parameters defined will be run once for every set of
+       parameters in the list.
+
+    Returns
+    -------
+    List[function]
+       Function outputs from pytest.fixture.  These should be unpacked
+       into individual named parameters.
+
+    Example
+    -------
+    >>> size, dtype = tvm.testing.parameters( (16,'float32'), (512,'float16') )
+    >>> def test_feature_x(size, dtype):
+    >>>     # Test code here
+    >>>     assert( (size,dtype) in [(16,'float32'), (512,'float16')])
+
+    """
+    global _parametrize_group
+    parametrize_group = _parametrize_group
+    _parametrize_group += 1
+
+    outputs = []
+    for param_values in zip(*value_sets):
+
+        # Optional cls parameter in case a parameter is defined inside a
+        # class scope.
+        def fixture_func(*cls, request):

Review comment:
       And same change made here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to