This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 95cb0de27a [VULKAN] Fix CLZ support for Vulkan (#16858)
95cb0de27a is described below

commit 95cb0de27a8bcfe0586f38d8b0d2da955cf01432
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Wed Apr 10 20:21:20 2024 +0800

    [VULKAN] Fix CLZ support for Vulkan (#16858)
    
    CLZ (count leading zeros) is used to improve ceil_log2 performance
    on Vulkan. However, the current implementation is incorrect during dtype
    conversion. This PR contains:
    
    1. Simplify clz in index calculations (as happens in the Vulkan sort)
    2. Fix clz during data type conversion
---
 python/tvm/target/detect_target.py                   |  3 ++-
 src/arith/rewrite_simplify.cc                        | 11 +++++++++++
 src/tir/ir/data_type_rewriter.cc                     | 11 +++++++++++
 tests/python/arith/test_arith_rewrite_simplify.py    | 20 ++++++++++++++++++--
 .../test_tir_transform_force_narrow_index_to_i32.py  | 19 +++++++++++++++++++
 5 files changed, 61 insertions(+), 3 deletions(-)

diff --git a/python/tvm/target/detect_target.py 
b/python/tvm/target/detect_target.py
index aada611642..a2fe5e1f8b 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -67,8 +67,9 @@ def _detect_vulkan(dev: Device) -> Target:
             "max_shared_memory_per_block": dev.max_shared_memory_per_block,
             "thread_warp_size": dev.warp_size,
             "supports_float16": f_get_target_property(dev, "supports_float16"),
-            "supports_int16": f_get_target_property(dev, "supports_int16"),
             "supports_int8": f_get_target_property(dev, "supports_int8"),
+            "supports_int16": f_get_target_property(dev, "supports_int16"),
+            "supports_int64": f_get_target_property(dev, "supports_int64"),
             "supports_16bit_buffer": f_get_target_property(dev, 
"supports_16bit_buffer"),
         }
     )
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e7e58a80fc..a4602bb8b9 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2250,6 +2250,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
CallNode* op) {
         }
       }
     }
+  } else if (op->op.same_as(Op::Get("tir.clz"))) {
+    if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
+      int bits = arg_int->dtype.bits();
+      if (arg_int->value == 0) return make_const(op->dtype, bits);
+      for (int i = bits - 1; i >= 0; --i) {
+        if ((int64_t(1) << i) & arg_int->value) {
+          return IntImm(op->dtype, bits - i - 1);
+        }
+      }
+      LOG(FATAL) << "Should not reach here";
+    }
   }
 
   if (op->op.same_as(tir::builtin::likely())) {
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 3461597b8e..a613b8d4bb 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -215,6 +215,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, 
operator>=);
 #undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
 
 PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
+  Call before = GetRef<Call>(op);
   PrimExpr e = StmtExprMutator::VisitExpr_(op);
   op = e.as<CallNode>();
   static const Op& builtin_pow_ = Op::Get("tir.pow");
@@ -234,6 +235,16 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) 
{
     return pow(op->args[0], op->args[1]);
   } else if (op->op.same_as(builtin::if_then_else())) {
     return if_then_else(op->args[0], op->args[1], op->args[2]);
+  } else if (op->op.same_as(Op::Get("tir.clz"))) {
+    DataType before_dtype = before->args[0]->dtype;
+    DataType after_dtype = op->args[0]->dtype;
+    CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || 
before_dtype.bits() == 64))
+        << "clz only supports 32 or 64 bit integer types, but get type before 
legalizing: "
+        << before_dtype;
+    CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || 
after_dtype.bits() == 64))
+        << "clz only supports 32 or 64 bit integer types, but get type after 
legalizing: "
+        << after_dtype;
+    return e - after_dtype.bits() + before_dtype.bits();
   }
   return e;
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index 9cc44aa6a2..6180167555 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -20,9 +20,12 @@ import inspect
 import pytest
 
 import tvm
+import tvm.testing
 from tvm import te, tir
-
-from tvm.tir import truncdiv as tdiv, truncmod as tmod, floordiv as fld, 
floormod as flm
+from tvm.tir import floordiv as fld
+from tvm.tir import floormod as flm
+from tvm.tir import truncdiv as tdiv
+from tvm.tir import truncmod as tmod
 
 
 class TestCase:
@@ -1150,5 +1153,18 @@ class TestIfThenElse(BaseCompare):
     )
 
 
+class TestCLZ(BaseCompare):
+    test_case = tvm.testing.parameter(
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 
tvm.tir.IntImm("int64", 0)), 64),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 
tvm.tir.IntImm("int64", 1)), 63),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 
tvm.tir.IntImm("int64", 2)), 62),
+        TestCase(tvm.tir.call_intrin("int32", "tir.clz", 
tvm.tir.IntImm("int64", 128)), 56),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py 
b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
index c1b81853de..0be0e5fbb5 100644
--- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
@@ -259,5 +259,24 @@ def test_pod_params_and_select():
     tvm.ir.assert_structural_equal(Expected, after)
 
 
+def test_clz():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(B: T.Buffer((T.int64(4),), "int32")):
+            for i in T.serial(T.int64(4)):
+                B[i] = T.clz(i)
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(B: T.Buffer((4,), "int32")):
+            for i in range(4):
+                B[i] = T.clz(i) - 32 + 64
+
+    after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to