This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 4b906554af [OpenCL] Add OpenCL device for automatic target detection (#16854)

4b906554af is described below

commit 4b906554af2bad9859b405f694b1c59d77d74785
Author: Mengshiun Yu <mengs...@gmail.com>
AuthorDate: Thu Apr 11 19:12:23 2024 +0800

    [OpenCL] Add OpenCL device for automatic target detection (#16854)

    This PR adds OpenCL device for automatic target detection.
---
 python/tvm/target/detect_target.py        | 14 +++++++++++++-
 tests/python/target/test_target_target.py | 12 ++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index a2fe5e1f8b..b23baa0313 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -58,6 +58,17 @@ def _detect_rocm(dev: Device) -> Target:
     )
 
 
+def _detect_opencl(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "opencl",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
 def _detect_vulkan(dev: Device) -> Target:
     f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
     return Target(
@@ -100,7 +111,7 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target:
     ----------
     dev : Union[str, Device]
         The device to detect the target for.
-        Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+        Supported device types: ["cuda", "metal", "rocm", "vulkan", "opencl"]
 
     Returns
     -------
@@ -129,4 +140,5 @@ SUPPORT_DEVICE = {
     "metal": _detect_metal,
     "vulkan": _detect_vulkan,
     "rocm": _detect_rocm,
+    "opencl": _detect_opencl,
 }
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 83bd864970..e977ef10aa 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -547,5 +547,17 @@ def test_target_from_device_rocm(input_device):
     )
 
 
+@tvm.testing.requires_opencl
+@pytest.mark.parametrize("input_device", ["opencl", tvm.opencl()])
+def test_target_from_device_opencl(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.opencl()
+    assert target.kind.name == "opencl"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
 if __name__ == "__main__":
     tvm.testing.main()