This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b906554af [OpenCL] Add OpenCL device for automatic target detection (#16854)
4b906554af is described below

commit 4b906554af2bad9859b405f694b1c59d77d74785
Author: Mengshiun Yu <mengs...@gmail.com>
AuthorDate: Thu Apr 11 19:12:23 2024 +0800

    [OpenCL] Add OpenCL device for automatic target detection (#16854)
    
    This PR adds OpenCL device for automatic target detection.
---
 python/tvm/target/detect_target.py        | 14 +++++++++++++-
 tests/python/target/test_target_target.py | 12 ++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index a2fe5e1f8b..b23baa0313 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -58,6 +58,17 @@ def _detect_rocm(dev: Device) -> Target:
     )
 
 
+def _detect_opencl(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "opencl",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
 def _detect_vulkan(dev: Device) -> Target:
     f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
     return Target(
@@ -100,7 +111,7 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target:
     ----------
     dev : Union[str, Device]
         The device to detect the target for.
-        Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+        Supported device types: ["cuda", "metal", "rocm", "vulkan", "opencl"]
 
     Returns
     -------
@@ -129,4 +140,5 @@ SUPPORT_DEVICE = {
     "metal": _detect_metal,
     "vulkan": _detect_vulkan,
     "rocm": _detect_rocm,
+    "opencl": _detect_opencl,
 }
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 83bd864970..e977ef10aa 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -547,5 +547,17 @@ def test_target_from_device_rocm(input_device):
     )
 
 
+@tvm.testing.requires_opencl
+@pytest.mark.parametrize("input_device", ["opencl", tvm.opencl()])
+def test_target_from_device_opencl(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.opencl()
+    assert target.kind.name == "opencl"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to