This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6252fa5802 [TIR] Enhance CLZ intrinsic support (#16952)
6252fa5802 is described below

commit 6252fa5802c94df522306519da94b874b3a45eda
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Tue Apr 30 14:14:44 2024 +0800

    [TIR] Enhance CLZ intrinsic support (#16952)
---
 .github/workflows/main.yml                         |  2 +
 src/target/intrin_rule.h                           | 18 +++++--
 src/target/source/intrin_rule_cuda.cc              | 12 +++++
 src/target/source/intrin_rule_metal.cc             |  3 ++
 src/target/source/intrin_rule_opencl.cc            |  3 ++
 src/tir/ir/data_type_rewriter.cc                   |  6 ++-
 .../codegen/test_target_codegen_gpu_common.py      | 55 ++++++++++++++++++++++
 7 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d63af560d7..759acd1fa5 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -77,6 +77,8 @@ jobs:
       - name: Minimal Metal Compile-and-Run
         shell: bash -l {0}
         run: >-
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_metal.py'
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_gpu_common.py'
           python -m pytest -v -s 'tests/python/codegen/test_gpu_codegen_allreduce.py::test_allreduce_sum[dims0-metal]'
 #      - name: Test iOS RPC
 #        shell: bash -l {0}
diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h
index 2695c43173..ea8ccd98b1 100644
--- a/src/target/intrin_rule.h
+++ b/src/target/intrin_rule.h
@@ -53,8 +53,13 @@ struct Direct {
   std::string operator()(DataType t, std::string name) const { return name; }
 };
 
-// Call pure extern function.
-template <typename T>
+/*!
+ * \brief Dispatch pure extern function.
+ * \param e The call expression.
+ * \tparam T The function to dispatch.
+ * \tparam dtype_from_arg Whether the dtype is from the first argument or the call node
+ */
+template <typename T, bool dtype_from_arg = false>
 inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   const CallNode* call = e.as<CallNode>();
   ICHECK(call != nullptr);
@@ -64,7 +69,14 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   ICHECK(op != nullptr);
   std::string name = op->name;
   ICHECK_EQ(name.substr(0, 4), "tir.");
-  name = T()(call->dtype, name.substr(4));
+  DataType dtype;
+  if (dtype_from_arg) {
+    ICHECK_EQ(call->args.size(), 1U);
+    dtype = call->args[0].dtype();
+  } else {
+    dtype = call->dtype;
+  }
+  name = T()(dtype, name.substr(4));
 
   if (name.length() != 0) {
     Array<PrimExpr> new_args = {StringImm(name)};
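
The new dtype_from_arg flag matters for tir.clz because the result dtype of the
call does not track the argument width. A minimal illustration, assuming a local
TVM build (expected values shown as comments):

    import tvm

    call = tvm.tir.clz(tvm.tir.const(7, "int64"))
    print(call.dtype)          # int32 - the clz call itself yields an int32
    print(call.args[0].dtype)  # int64 - the width that must select the extern

Dispatching on call->dtype would therefore pick the 32-bit extern even for 64-bit
inputs, which is why the CUDA rule below registers tir.clz with dtype_from_arg=true.
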
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 95fbf7f1a5..79ea7a458f 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,15 @@ struct CUDAMath {
       }
     } else if (t.is_bfloat16()) {
       return 'h' + name;
+    } else if (t.is_int() || t.is_uint()) {
+      switch (t.bits()) {
+        case 32:
+          return "__" + name;
+        case 64:
+          return "__" + name + "ll";
+        default:
+          return "";
+      }
     }
     return "";
   }
@@ -133,6 +142,9 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
+    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath, 
/*dtype_from_arg=*/true>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
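
With the CUDAMath extension above, a 32-bit clz argument lowers to __clz and a
64-bit one to __clzll. One way to see the lowering, sketched under the assumption
of a CUDA-enabled TVM build (names and shapes are illustrative, mirroring the new
test added below):

    import tvm
    from tvm import te

    A = te.placeholder((8,), name="A", dtype="int64")
    B = te.compute(A.shape, lambda i: tvm.tir.clz(A[i]), name="B")
    sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
    (x,) = sch.get_loops(sch.get_block("B"))
    sch.bind(x, "threadIdx.x")
    f = tvm.build(sch.mod, target="cuda")
    # The generated kernel is expected to call __clzll for the int64 input.
    assert "__clzll" in f.imported_modules[0].get_source()
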
 
diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc
index 50685f6ef2..b7561e8671 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -52,6 +52,9 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 94ab9d8b9d..bd9e148b18 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -31,6 +31,9 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index a613b8d4bb..c03e19137e 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -238,10 +238,12 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(Op::Get("tir.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
-    CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64))
+    CHECK((before_dtype.is_int() || before_dtype.is_uint()) &&
+          (before_dtype.bits() == 32 || before_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type before 
legalizing: "
         << before_dtype;
-    CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64))
+    CHECK((after_dtype.is_int() || after_dtype.is_uint()) &&
+          (after_dtype.bits() == 32 || after_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type after 
legalizing: "
         << after_dtype;
     return e - after_dtype.bits() + before_dtype.bits();
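
The returned expression corrects for index legalization widening the clz argument:
the raw leading-zero count grows by exactly the number of added bits, so subtracting
after_dtype.bits() and adding back before_dtype.bits() recovers the result for the
original width. A small worked example in plain Python:

    x = 1
    clz32 = 32 - x.bit_length()   # 31 leading zeros as a 32-bit integer
    clz64 = 64 - x.bit_length()   # 63 leading zeros after widening to 64 bits
    assert clz64 - 64 + 32 == clz32
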
diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py
new file mode 100644
index 0000000000..2941f366a4
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_gpu_common.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from functools import partial
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal", "vulkan -supports_int64=1", 
"opencl")
+@pytest.mark.parametrize("dtype", ["int32", "uint32", "int64", "uint64"])
+def test_int_intrin(target, dev, dtype):
+    test_funcs = [
+        (tvm.tir.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)),
+    ]
+
+    def run_test(tvm_intrin, np_func, dtype):
+        n = 128
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B")
+        func = te.create_prim_func([A, B])
+        sch = tvm.tir.Schedule(func)
+        (x,) = sch.get_loops(sch.get_block("B"))
+        sch.bind(x, "threadIdx.x")
+        f = tvm.build(sch.mod, target=target)
+        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
+        f(a, b)
+        ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy())
+        tvm.testing.assert_allclose(b.numpy(), ref)
+
+    for func in test_funcs:
+        run_test(*func, dtype)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
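
A quick sanity check of the numpy reference used in the test above, with a
hypothetical input: int(dtype[-2:]) extracts the bit width (32 or 64) and
len(bin(x)) - 2 is the number of significant bits of x, so the difference is the
leading-zero count.

    x = 100000                             # needs 17 significant bits
    assert 32 - (len(bin(x)) - 2) == 15    # expected clz for a 32-bit value
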
