This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 95cb0de27a [VULKAN] Fix CLZ support for Vulkan (#16858) 95cb0de27a is described below commit 95cb0de27a8bcfe0586f38d8b0d2da955cf01432 Author: Siyuan Feng <hzfen...@sjtu.edu.cn> AuthorDate: Wed Apr 10 20:21:20 2024 +0800 [VULKAN] Fix CLZ support for Vulkan (#16858) CLZ (counting leading zeros) is used for improving ceil_log2 performance on Vulkan. However, the current implementation is incorrect during dtype conversion. This PR contains: 1. Simplify clz for index calculation (happens in Vulkan sort) 2. Fix clz for data type conversion --- python/tvm/target/detect_target.py | 3 ++- src/arith/rewrite_simplify.cc | 11 +++++++++++ src/tir/ir/data_type_rewriter.cc | 11 +++++++++++ tests/python/arith/test_arith_rewrite_simplify.py | 20 ++++++++++++++++++-- .../test_tir_transform_force_narrow_index_to_i32.py | 19 +++++++++++++++++++ 5 files changed, 61 insertions(+), 3 deletions(-) diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index aada611642..a2fe5e1f8b 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -67,8 +67,9 @@ def _detect_vulkan(dev: Device) -> Target: "max_shared_memory_per_block": dev.max_shared_memory_per_block, "thread_warp_size": dev.warp_size, "supports_float16": f_get_target_property(dev, "supports_float16"), - "supports_int16": f_get_target_property(dev, "supports_int16"), "supports_int8": f_get_target_property(dev, "supports_int8"), + "supports_int16": f_get_target_property(dev, "supports_int16"), + "supports_int64": f_get_target_property(dev, "supports_int64"), "supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"), } ) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e7e58a80fc..a4602bb8b9 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -2250,6 +2250,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* 
op) { } } } + } else if (op->op.same_as(Op::Get("tir.clz"))) { + if (const auto* arg_int = op->args[0].as<IntImmNode>()) { + int bits = arg_int->dtype.bits(); + if (arg_int->value == 0) return make_const(op->dtype, bits); + for (int i = bits - 1; i >= 0; --i) { + if ((int64_t(1) << i) & arg_int->value) { + return IntImm(op->dtype, bits - i - 1); + } + } + LOG(FATAL) << "Should not reach here"; + } } if (op->op.same_as(tir::builtin::likely())) { diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 3461597b8e..a613b8d4bb 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -215,6 +215,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); #undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { + Call before = GetRef<Call>(op); PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as<CallNode>(); static const Op& builtin_pow_ = Op::Get("tir.pow"); @@ -234,6 +235,16 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { return pow(op->args[0], op->args[1]); } else if (op->op.same_as(builtin::if_then_else())) { return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->op.same_as(Op::Get("tir.clz"))) { + DataType before_dtype = before->args[0]->dtype; + DataType after_dtype = op->args[0]->dtype; + CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64)) + << "clz only supports 32 or 64 bit integer types, but get type before legalizing: " + << before_dtype; + CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64)) + << "clz only supports 32 or 64 bit integer types, but get type after legalizing: " + << after_dtype; + return e - after_dtype.bits() + before_dtype.bits(); } return e; } diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 9cc44aa6a2..6180167555 100644 --- 
a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -20,9 +20,12 @@ import inspect import pytest import tvm +import tvm.testing from tvm import te, tir - -from tvm.tir import truncdiv as tdiv, truncmod as tmod, floordiv as fld, floormod as flm +from tvm.tir import floordiv as fld +from tvm.tir import floormod as flm +from tvm.tir import truncdiv as tdiv +from tvm.tir import truncmod as tmod class TestCase: @@ -1150,5 +1153,18 @@ class TestIfThenElse(BaseCompare): ) +class TestCLZ(BaseCompare): + test_case = tvm.testing.parameter( + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py index c1b81853de..0be0e5fbb5 100644 --- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py +++ b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py @@ -259,5 +259,24 @@ def test_pod_params_and_select(): tvm.ir.assert_structural_equal(Expected, after) +def test_clz(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((T.int64(4),), "int32")): + for i in T.serial(T.int64(4)): + B[i] = T.clz(i) + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(B: T.Buffer((4,), "int32")): + for i in range(4): + B[i] = 
T.clz(i) - 32 + 64 + + after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before) + tvm.ir.assert_structural_equal(Expected, after) + + if __name__ == "__main__": tvm.testing.main()