This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch vk-i64 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 15189caf985e2edcb56b679497cd265f298395d2
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Mar 4 17:03:35 2021 -0500

    add test
---
 tests/python/topi/python/test_topi_cumsum.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index 6b99239..79330e7 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -29,6 +29,7 @@ def test_cumsum(ctx, target):
             "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
         tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -47,6 +48,9 @@ def test_cumsum(ctx, target):
     check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
 
     for in_dtype in ["float32", "float64"]:
+        if str(target.kind) == 'metal' and in_dtype == 'float64':
+            # float64 is not supported in metal
+            continue
         data = np.random.randn(10, 10).astype(in_dtype)
         check_cumsum(np.cumsum(data), data)
         check_cumsum(np.cumsum(data, axis=0), data, axis=0)
@@ -74,3 +78,4 @@ if __name__ == "__main__":
     test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
     test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
     test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
+    test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))