This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b9881e  [CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (#8756)
4b9881e is described below

commit 4b9881ec50008bc14fc1ae7805413544cf962011
Author: Yuan-Chuan-YUE <69908243+yuan-chuan-...@users.noreply.github.com>
AuthorDate: Sun Aug 22 05:42:21 2021 +0800

    [CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (#8756)
    
    * register tir.erf to lower opencl directly
    
    * add opencl codegen unit test
    
    * change erf opencl codegen unit test for checking there is erf in the source not erff
---
 src/target/source/intrin_rule_opencl.cc             |  3 +++
 tests/python/unittest/test_target_codegen_opencl.py | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 288bb2c..64a50c3 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round")
 TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
                                                      DispatchPureExtern<Direct>);
 
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.exp2")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py
index 98340f0..56392ec 100644
--- a/tests/python/unittest/test_target_codegen_opencl.py
+++ b/tests/python/unittest/test_target_codegen_opencl.py
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 import tvm.testing
+import re
 
 target = "opencl"
 
@@ -120,6 +121,25 @@ def test_opencl_max():
     check_max(dev, 1, "float64")
 
 
+def test_opencl_erf():
+    def check_erf(dev, n, dtype):
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
+        s = te.create_schedule(C.op)
+        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, C], target)
+        source_str = fun.imported_modules[0].get_source()
+        matches = re.findall("erf", source_str)
+        error_matches = re.findall("erff", source_str)
+        assert len(matches) == 1 and len(error_matches) == 0
+
+    dev = tvm.device(target, 0)
+
+    check_erf(dev, 1, "float32")
+    check_erf(dev, 1, "float64")
+
+
 if __name__ == "__main__":
     test_opencl_ternary_expression()
     test_opencl_inf_nan()
+    test_opencl_erf()

Reply via email to